「NOI2016」优秀的拆分

传送门

观察到这题 \(O(n^2)\)\(95\) 分,不妨先想想 \(O(n^2)\) 怎么做。

如果我们用 \(f_i\) 表示以 \(i\) 结尾的 \(\text{AA}\) 串的个数,\(g_i\) 表示以 \(i\) 开头的 \(\text{AA}\) 串的个数,那么答案就是:

\[\sum_{i = 1} ^ {n - 1} f_i \times g_{i + 1} \]

直接计算 \(f_i, g_i\)\(O(n ^ 2)\) 的,不过居然有 \(95\) 分啊 \(Q \omega Q\)

考虑如何较快的计算 \(f_i\)\(g_i\)

我们可以这么处理:枚举 \(\text{AA}\) 串中 \(\text{A}\) 的长度 \(len\) ,然后在原串中每 \(len\) 个位置标记一个关键点,这样一来出现在原串中的 \(\text{AA}\) 串无论如何都会经过且只经过两个相邻的关键点。

有一说一 这谁想得到啊

那么我们就考虑用两个相邻的关键点来算 \(\text{AA}\) 的个数。

对于两个相邻的关键点 \(i, j\),如果我们求出 \(\text{LCS}(i - 1, j - 1)\)\(\text{LCP}(i, j)\),分别记为 \(lcs, lcp\),那么我们就只需要考察 \(lcs\)\(lcp\) 在区间 \([i, j]\) 内的交的情况就好了(此处我们认为贴在一起也算有交)

结合一个图来看:

\[\underbrace{.......i-1}_{lcs}\;\overbrace{\underbrace{i........}_{lcp}\;....\underbrace{.......j-1}_{lcs}}^{len}\;\underbrace{j........}_{lcp} \]

如果 \(lcs\)\(lcp\) 有交的话,那么右端点在这个交里的 \(\text{A}\) 都可以在右边复制一遍形成一个 \(\text{AA}\) 串。

具体原理看看图,自己想一想。

那么也就是说,我们每次这么构造出来的 \(\text{AA}\) 串都是连续分布的,我们就考虑差分统计,最后再像之前那样算答案就好了。

参考代码:

#include <cstring>
#include <cstdio>

int min(int a, int b) { return a < b ? a : b; } 

typedef long long LL;
const int _ = 1e5 + 5, base = 131, mod1 = 1e9 + 7, mod2 = 1e9 + 9;

char s[_]; int n, f[_], g[_];
int pw1[_], hash1[_], pw2[_], hash2[_];

int Hash1(int l, int r) { return (mod1 + hash1[r] - 1ll * hash1[l - 1] * pw1[r - l + 1] % mod1) % mod1; }

int Hash2(int l, int r) { return (mod2 + hash2[r] - 1ll * hash2[l - 1] * pw2[r - l + 1] % mod2) % mod2; }

int chk(int a, int b, int c, int d) { return Hash1(a, b) == Hash1(c, d) && Hash2(a, b) == Hash2(c, d); }

int LCP(int i, int j) {
    int l = 0, r = min(n - i + 1, n - j + 1);
    while (l < r) {
        int mid = (l + r + 1) >> 1;
        if (chk(i, i + mid - 1, j, j + mid - 1)) l = mid; else r = mid - 1;
    }
    return l;
}

int LCS(int i, int j) {
    int l = 0, r = min(i, j);
    while (l < r) {
        int mid = (l + r + 1) >> 1;
        if (chk(i - mid + 1, i, j - mid + 1, j)) l = mid; else r = mid - 1;
    }
    return l;
}

void solve() {
    memset(hash1, 0, sizeof hash1);
    memset(hash2, 0, sizeof hash2);
    memset(f, 0, sizeof f);
    memset(g, 0, sizeof g);
    scanf("%s", s + 1), n = strlen(s + 1);
    for (int i = 1; i <= n; ++i) {
        hash1[i] = (1ll * hash1[i - 1] * base % mod1 + s[i]) % mod1;
        hash2[i] = (1ll * hash2[i - 1] * base % mod2 + s[i]) % mod2;
    }
    for (int len = 1; len << 1 <= n; ++len)
        for (int i = len, j = i + len; j <= n; i += len, j += len) {
            int lcp = min(len, LCP(i, j)), lcs = min(len - 1, LCS(i - 1, j - 1));
            int cap = lcp + lcs - len + 1;
            if (lcp + lcs >= len) {
                ++g[i - lcs], --g[i - lcs + cap];
                ++f[j + lcp - cap], --f[j + lcp];
            }
        }
    for (int i = 1; i <= n; ++i) f[i] += f[i - 1], g[i] += g[i - 1];
    LL ans = 0;
    for (int i = 1; i < n; ++i) ans += 1ll * f[i] * g[i + 1];
    printf("%lld\n", ans);
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("cpp.in", "r", stdin), freopen("cpp.out", "w", stdout);
#endif
    pw1[0] = pw2[0] = 1;
    for (int i = 1; i < _; ++i) {
        pw1[i] = 1ll * pw1[i - 1] * base % mod1;
        pw2[i] = 1ll * pw2[i - 1] * base % mod2;
    }
    int T; scanf("%d", &T);
    while (T--) solve();
    return 0;
}
posted @ 2020-06-11 20:35  Sangber  阅读(246)  评论(0编辑  收藏  举报