CF1163D(DP+KMP)

CF1163D(DP+KMP)

Problem - D - Codeforces

题意

给三个串 \(s,a,b\) 。最大化 \(f(s,a)-f(s,b)\)

\(f(s,a)\) 表示 \(a\)\(s\) 中出现的次数(可重复)

\(|s| \le 1000, |a|,|b| \le 50\)

思路

解最优化问题,可以考虑 \(dp\)

对字符串做 \(dp\) ,我们可以根据匹配状态定义状态。

定义 \(dp[i][j][k]\) 表示考虑 \(s\)\(i\) 个字符,分别后缀最大匹配了 \(a,b\) 的前缀状态是 \(j,k\) 时的答案。

那么 \(max(dp[|s|][j][k])\) 就是答案。

考虑状态 \((i-1,j,k)\) 的转移。当我们新加入一个字符 \(s_i\) 时,\(a,b\) 新的匹配状态设为 \(nj,nk\) 。此时需要考虑:


当母串新纳入一个字符后,模式串的匹配状态会变成什么。kmp已经回答了这个问题。

如果匹配,显然匹配状态+1。

如果失配,那么我需要快速找到能够匹配的前缀状态,这就是不断跳 \(next\) 数组的过程。


那么我们只要跑kmp就可以求出 \(nj,nk\) 了。

\[f[i][nj][nk]=max(f[i][nj][nk],f[i][j][k]+(nj==|a|)-(nk==|b|)) \]

所以具体做法就是先跑出kmp,之后做dp,在转移过程中用跳 \(next\)。遇到 \('*'\) 暴力枚举就行。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#include<queue>
#include<map>
#include<stack>
#include<string>
#include<functional>
#include<cassert>
#include<random>
#include<iomanip>
#define yes puts("yes");
#define inf 0x3f3f3f3f
#define ll long long
#define linf 0x3f3f3f3f3f3f3f3fll
#define ull unsigned long long
#define endl '\n'
#define int long long
#define SZ(x) (int)x.size()
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define dec(i,n,a) for(int i = n;i >= a;i--)
using namespace std;
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x;}
using PII = array<int,2>;
const int MAXN =10 + 2e5 ,mod=1e9 + 7;

int f[1010][55][55];
void solve()
{    
    string s,a,b; cin >> s >> a >> b;
    int n = s.size(), la = a.size(), lb = b.size();
    
    s = '$' + s;
    
    auto getNxt = [&](string &s) {
        int n = s.size();
        s = '$' + s;
        vector<int> nxt(n + 1);
        for(int i = 2,j = 0;i <= n;i += 1) {
            while(j and s[i] != s[j + 1])
                j = nxt[j];
            nxt[i] = (j += s[i] == s[j + 1]);
        }
        return nxt;
    };

    auto na = getNxt(a), nb = getNxt(b);

    memset(f,0xcf,sizeof f);
    f[0][0][0] = 0;

    rep(i,1,n) {
        if(s[i] != '*') {
            rep(j,0,min(la - 1,i)) rep(k,0,min(lb - 1,i)) {
                int prea = j, preb = k;
                // 新加字符s[i] j,k 回退
                while(prea and a[prea + 1] != s[i]) prea = na[prea];
                while(preb and b[preb + 1] != s[i]) preb = nb[preb];
                prea += a[prea + 1] == s[i];
                preb += b[preb + 1] == s[i];
                int d = (prea == la) - (preb == lb);
                if(la == prea) prea = na[prea];
                if(lb == preb) preb = nb[preb];
                f[i][prea][preb] = max(f[i - 1][j][k] + d, f[i][prea][preb]);
            }
        }else {
            rep(ch,0,25) {
                char c = ch + 'a';
                rep(j,0,min(la - 1,i)) rep(k,0,min(lb - 1,i)) {
                    int prea = j, preb = k;
                    s[i] = c;
                    // 新加字符s[i] j,k 回退
                    while(prea and a[prea + 1] != s[i]) prea = na[prea];
                    while(preb and b[preb + 1] != s[i]) preb = nb[preb];
                    prea += a[prea + 1] == s[i];
                    preb += b[preb + 1] == s[i];
                    int d = (prea == la) - (preb == lb);
                    if(la == prea) prea = na[prea];
                    if(lb == preb) preb = nb[preb];
                    f[i][prea][preb] = max(f[i - 1][j][k] + d, f[i][prea][preb]);
                }    
            }
        }
    }
    int ans = -linf;
    rep(i,0,la) rep(j,0,lb) ans = max(ans, f[n][i][j]);
    cout << ans;

}
signed main()
{
    ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);

    //int T;cin>>T;
    //while(T--)
        solve();

    return 0;
}

有几个细节:

  1. dp数组初始化要为负无穷。
  2. 在跳完 \(next\) 后,如果得到 \(|a|==nj\) 。我们还要将它做一次回退。这个和跑kmp时当完成一次匹配后,要回退很类似。(具体原因还不太懂)
  3. 只跑 \(1 \sim |a|-1\) 的状态,为了防止越界。

优化

可以发现,求 \(nj,nk\) 这个过程和 \(s\) 具体是什么无关,它只和 \(i\) 当前匹配字符和前面的 \(next\) 有关。所以我们可以在kmp时将这个预处理出来,优化掉该过程。


如果对kmp足够熟悉,可以发现这个就是建立状态机的过程


auto getNxt = [&](string &s) {
        int n = s.size();
        s = '$' + s;
        vector<int> nxt(n + 1);
        for(int i = 2,j = 0;i <= n;i += 1) {
            while(j and s[i] != s[j + 1])
                j = nxt[j];
            nxt[i] = (j += s[i] == s[j + 1]);
        }
        vector<vector<int>> nt(n + 1,vector<int>(26));
        rep(i,0,n) rep(j,0,25) {
            int cur = i;
            while(cur and (cur == n or s[cur + 1] != j + 'a'))
                cur = nxt[cur];
            nt[i][j] = (cur += s[cur + 1] == j + 'a');
        }
        return nt;
    };

从这个代码就可以看到,当匹配到字符串末尾时,很自然的可以想到需要做回退,否则会出现越界的情况。


进一步的,我们可以对 \(dp[i][j][k]\) 的意义做另一个角度诠释:考虑前 \(i\) 的字符串,kmp状态(最大前缀匹配状态)为 \(j,k\) 时的答案。可以看成在kmp状态构成的分层图(根据i的不同分层)上做dp。


于是可以将复杂度优化为 \(O(26*|s||a||b|)\)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#include<queue>
#include<map>
#include<stack>
#include<string>
#include<functional>
#include<cassert>
#include<random>
#include<iomanip>
#define yes puts("yes");
#define inf 0x3f3f3f3f
#define ll long long
#define linf 0x3f3f3f3f3f3f3f3fll
#define ull unsigned long long
#define endl '\n'
#define int long long
#define SZ(x) (int)x.size()
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define dec(i,n,a) for(int i = n;i >= a;i--)
using namespace std;
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x;}
using PII = array<int,2>;
const int MAXN =10 + 2e5 ,mod=1e9 + 7;

int f[1010][55][55];

void solve()
{    
    string s,a,b; cin >> s >> a >> b;
    int n = s.size(), la = a.size(), lb = b.size();
    
    s = '$' + s;
    
    auto getNxt = [&](string &s) {
        int n = s.size();
        s = '$' + s;
        vector<int> nxt(n + 1);
        for(int i = 2,j = 0;i <= n;i += 1) {
            while(j and s[i] != s[j + 1])
                j = nxt[j];
            nxt[i] = (j += s[i] == s[j + 1]);
        }
        vector<vector<int>> nt(n + 1,vector<int>(26));
        rep(i,0,n) rep(j,0,25) {
            int cur = i;
            while(cur and (cur == n or s[cur + 1] != j + 'a'))
                cur = nxt[cur];
            nt[i][j] = (cur += s[cur + 1] == j + 'a');
        }
        return nt;
    };

    auto na = getNxt(a), nb = getNxt(b);

    memset(f,0xcf,sizeof f);
    f[0][0][0] = 0;

    rep(i,1,n) rep(j,0,min(la,i)) rep(k,0,min(lb,i)) if(s[i] != '*') {
        int pa = na[j][s[i] - 'a'], pb = nb[k][s[i] - 'a'];
        int t = f[i - 1][j][k] + (pa == la) - (pb == lb);
        f[i][pa][pb] = max(t, f[i][pa][pb]);
    }else rep(c,0,25) {
        int pa = na[j][c], pb = nb[k][c];
        int t = f[i - 1][j][k] + (pa == la) - (pb == lb);
        f[i][pa][pb] = max(t, f[i][pa][pb]);
    }
    
    int ans = -linf;
    rep(i,0,la) rep(j,0,lb) ans = max(ans, f[n][i][j]);
    cout << ans;

}
signed main()
{
    ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);

    //int T;cin>>T;
    //while(T--)
        solve();

    return 0;
}
posted @ 2022-09-15 20:44  Mxrurush  阅读(16)  评论(0编辑  收藏  举报