CF1163D(DP+KMP)
CF1163D(DP+KMP)
题意
给三个串 \(s,a,b\) 。最大化 \(f(s,a)-f(s,b)\) 。
\(f(s,a)\) 表示 \(a\) 在 \(s\) 中出现的次数(可重复)
\(|s| \le 1000, |a|,|b| \le 50\)
思路
解最优化问题,可以考虑 \(dp\) 。
对字符串做 \(dp\) ,我们可以根据匹配状态定义状态。
定义 \(dp[i][j][k]\) 表示考虑 \(s\) 前 \(i\) 个字符,分别后缀最大匹配了 \(a,b\) 的前缀状态是 \(j,k\) 时的答案。
那么 \(max(dp[|s|][j][k])\) 就是答案。
考虑状态 \((i-1,j,k)\) 的转移。当我们新加入一个字符 \(s_i\) 时,\(a,b\) 新的匹配状态设为 \(nj,nk\) 。此时需要考虑:
当母串新纳入一个字符后,模式串的匹配状态会变成什么。kmp已经回答了这个问题。
如果匹配,显然匹配状态+1。
如果失配,那么我需要快速找到能够匹配的前缀状态,这就是不断跳 \(next\) 数组的过程。
那么我们只要跑kmp就可以求出 \(nj,nk\) 了。
所以具体做法就是先跑出kmp,之后做dp,在转移过程中用跳 \(next\)。遇到 \('*'\) 暴力枚举就行。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#include<queue>
#include<map>
#include<stack>
#include<string>
#include<functional>
#include<cassert>
#include<random>
#include<iomanip>
#define yes puts("yes");
#define inf 0x3f3f3f3f
#define ll long long
#define linf 0x3f3f3f3f3f3f3f3fll
#define ull unsigned long long
#define endl '\n'
#define int long long
#define SZ(x) (int)x.size()
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define dec(i,n,a) for(int i = n;i >= a;i--)
using namespace std;
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x;}
using PII = array<int,2>;
const int MAXN =10 + 2e5 ,mod=1e9 + 7;
int f[1010][55][55];
void solve()
{
string s,a,b; cin >> s >> a >> b;
int n = s.size(), la = a.size(), lb = b.size();
s = '$' + s;
auto getNxt = [&](string &s) {
int n = s.size();
s = '$' + s;
vector<int> nxt(n + 1);
for(int i = 2,j = 0;i <= n;i += 1) {
while(j and s[i] != s[j + 1])
j = nxt[j];
nxt[i] = (j += s[i] == s[j + 1]);
}
return nxt;
};
auto na = getNxt(a), nb = getNxt(b);
memset(f,0xcf,sizeof f);
f[0][0][0] = 0;
rep(i,1,n) {
if(s[i] != '*') {
rep(j,0,min(la - 1,i)) rep(k,0,min(lb - 1,i)) {
int prea = j, preb = k;
// 新加字符s[i] j,k 回退
while(prea and a[prea + 1] != s[i]) prea = na[prea];
while(preb and b[preb + 1] != s[i]) preb = nb[preb];
prea += a[prea + 1] == s[i];
preb += b[preb + 1] == s[i];
int d = (prea == la) - (preb == lb);
if(la == prea) prea = na[prea];
if(lb == preb) preb = nb[preb];
f[i][prea][preb] = max(f[i - 1][j][k] + d, f[i][prea][preb]);
}
}else {
rep(ch,0,25) {
char c = ch + 'a';
rep(j,0,min(la - 1,i)) rep(k,0,min(lb - 1,i)) {
int prea = j, preb = k;
s[i] = c;
// 新加字符s[i] j,k 回退
while(prea and a[prea + 1] != s[i]) prea = na[prea];
while(preb and b[preb + 1] != s[i]) preb = nb[preb];
prea += a[prea + 1] == s[i];
preb += b[preb + 1] == s[i];
int d = (prea == la) - (preb == lb);
if(la == prea) prea = na[prea];
if(lb == preb) preb = nb[preb];
f[i][prea][preb] = max(f[i - 1][j][k] + d, f[i][prea][preb]);
}
}
}
}
int ans = -linf;
rep(i,0,la) rep(j,0,lb) ans = max(ans, f[n][i][j]);
cout << ans;
}
signed main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//int T;cin>>T;
//while(T--)
solve();
return 0;
}
有几个细节:
- dp数组初始化要为负无穷。
- 在跳完 \(next\) 后,如果得到 \(|a|==nj\) 。我们还要将它做一次回退。这个和跑kmp时当完成一次匹配后,要回退很类似。(具体原因还不太懂)
- 只跑 \(1 \sim |a|-1\) 的状态,为了防止越界。
优化
可以发现,求 \(nj,nk\) 这个过程和 \(s\) 具体是什么无关,它只和 \(i\) 当前匹配字符和前面的 \(next\) 有关。所以我们可以在kmp时将这个预处理出来,优化掉该过程。
如果对kmp足够熟悉,可以发现这个就是建立状态机的过程
auto getNxt = [&](string &s) {
int n = s.size();
s = '$' + s;
vector<int> nxt(n + 1);
for(int i = 2,j = 0;i <= n;i += 1) {
while(j and s[i] != s[j + 1])
j = nxt[j];
nxt[i] = (j += s[i] == s[j + 1]);
}
vector<vector<int>> nt(n + 1,vector<int>(26));
rep(i,0,n) rep(j,0,25) {
int cur = i;
while(cur and (cur == n or s[cur + 1] != j + 'a'))
cur = nxt[cur];
nt[i][j] = (cur += s[cur + 1] == j + 'a');
}
return nt;
};
从这个代码就可以看到,当匹配到字符串末尾时,很自然的可以想到需要做回退,否则会出现越界的情况。
进一步的,我们可以对 \(dp[i][j][k]\) 的意义做另一个角度诠释:考虑前 \(i\) 的字符串,kmp状态(最大前缀匹配状态)为 \(j,k\) 时的答案。可以看成在kmp状态构成的分层图(根据i的不同分层)上做dp。
于是可以将复杂度优化为 \(O(26*|s||a||b|)\) 。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#include<queue>
#include<map>
#include<stack>
#include<string>
#include<functional>
#include<cassert>
#include<random>
#include<iomanip>
#define yes puts("yes");
#define inf 0x3f3f3f3f
#define ll long long
#define linf 0x3f3f3f3f3f3f3f3fll
#define ull unsigned long long
#define endl '\n'
#define int long long
#define SZ(x) (int)x.size()
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define dec(i,n,a) for(int i = n;i >= a;i--)
using namespace std;
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x;}
using PII = array<int,2>;
const int MAXN =10 + 2e5 ,mod=1e9 + 7;
int f[1010][55][55];
void solve()
{
string s,a,b; cin >> s >> a >> b;
int n = s.size(), la = a.size(), lb = b.size();
s = '$' + s;
auto getNxt = [&](string &s) {
int n = s.size();
s = '$' + s;
vector<int> nxt(n + 1);
for(int i = 2,j = 0;i <= n;i += 1) {
while(j and s[i] != s[j + 1])
j = nxt[j];
nxt[i] = (j += s[i] == s[j + 1]);
}
vector<vector<int>> nt(n + 1,vector<int>(26));
rep(i,0,n) rep(j,0,25) {
int cur = i;
while(cur and (cur == n or s[cur + 1] != j + 'a'))
cur = nxt[cur];
nt[i][j] = (cur += s[cur + 1] == j + 'a');
}
return nt;
};
auto na = getNxt(a), nb = getNxt(b);
memset(f,0xcf,sizeof f);
f[0][0][0] = 0;
rep(i,1,n) rep(j,0,min(la,i)) rep(k,0,min(lb,i)) if(s[i] != '*') {
int pa = na[j][s[i] - 'a'], pb = nb[k][s[i] - 'a'];
int t = f[i - 1][j][k] + (pa == la) - (pb == lb);
f[i][pa][pb] = max(t, f[i][pa][pb]);
}else rep(c,0,25) {
int pa = na[j][c], pb = nb[k][c];
int t = f[i - 1][j][k] + (pa == la) - (pb == lb);
f[i][pa][pb] = max(t, f[i][pa][pb]);
}
int ans = -linf;
rep(i,0,la) rep(j,0,lb) ans = max(ans, f[n][i][j]);
cout << ans;
}
signed main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//int T;cin>>T;
//while(T--)
solve();
return 0;
}