CF1562E - Rescue Niwen!(字符串哈希,dp)

source

题解

\(s(l,r)\)\(s\)区间\((l,r)\)的子串,\(s_p\)代表\(s\)\(p\)处的字符。通过观察(小数据/推导/瞎猜/看题解)可以发现,如果最优解中含有\(s(l,r)\),那么\(s(l,n)\)必然包含在最优解中。
证明:
假设答案中包含相邻的\(s(i,r_i)\)\(s(j,r_j)\)。(\(i<j且s(i,r_i)<s(j,r_j)\))假设后缀\(i\)和后缀\(j\)的公共前缀长度为\(l\)。那么有:

  1. 如果\(s_{r_i} \ge s_i + l\),那么有:\(s_{r_j}=s_j + l\)\(s_{r_i}=n\)。这样是最优的,因为这样\(s(i,r_i)<s(j,r_j)\)恒成立,不会对后面的序列有影响。
  2. 如果\(s_{r_i} < s_i + l\),意味着\(s_{r_j}=s_{r_i}+1\),不会比情况1更优。

\(r_i=n\)是最优的,即如果最优解中含有\(s(l,r)\),那么\(s(l,n)\)必然包含在最优解中。

知道这个结论后,就简单了。问题变为选后缀,直接dp计算。使用字符串哈希计算两个哈希的lcp,从而可以得到一个后缀接到另一个后缀前面时第一个前缀的长度是多少了。时间复杂度\(O(n^2\log n)\)

#include <bits/stdc++.h>

#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define mp make_pair
#define seteps(N) fixed << setprecision(N) 
typedef long long ll;
typedef unsigned long long ull;
using namespace std;
/*-----------------------------------------------------------------*/

ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f

const int N = 5e3 + 10;
const int M = 998244353;
const double eps = 1e-5;
const int base = 255;
const int rbase = 461932681;
int len[N][N], dp[N];
char s[N];
ull val[N];
ll rpw[N];
int n;

ull cal(int l, int r) {
    return (val[r] - val[l - 1] + M) * rpw[l - 1] % M;
}

int lcp(int p1, int p2) {
    int len = min(n - p1 + 1, n - p2 + 1);
    int l = 1, r = len;
    while(l <= r) {
        int mid = (l + r) / 2;
        if(cal(p1, p1 + mid - 1) != cal(p2, p2 + mid - 1)) {
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    return r;
}

int main() {
    rpw[0] = 1;
    for(int i = 1; i < N; i++) rpw[i] = rpw[i - 1] * rbase % M;
    IOS;
    int t;
    cin >> t;
    while(t--) {
        cin >> n;
        cin >> s + 1;
        ull pw = 1;
        for(int i = 1; i <= n; i++) {
            val[i] = (val[i - 1] + pw * s[i]) % M;
            pw = pw * base % M;
        }
        dp[1] = n;
        for(int i = 2; i <= n; i++) {
            dp[i] = n - i + 1;
            for(int j = i - 1; j >= 1; j--) {
                int len = lcp(j, i);
                int p1 = i + len, p2 = j + len;
                if(s[p1] > s[p2]) {
                    dp[i] = max(dp[i], dp[j] + n - p1 + 1);
                }
            }
        }
        int ans = 0;
        for(int i = 1; i <= n; i++) {
            ans = max(ans, dp[i]);
        }
        cout << ans << endl;
    }
}
posted @ 2021-11-02 20:29  limil  阅读(52)  评论(0编辑  收藏  举报