「NOI2016」优秀的拆分
观察到这题 \(O(n^2)\) 有 \(95\) 分,不妨先想想 \(O(n^2)\) 怎么做。
如果我们用 \(f_i\) 表示以 \(i\) 结尾的 \(\text{AA}\) 串的个数,\(g_i\) 表示以 \(i\) 开头的 \(\text{AA}\) 串的个数,那么答案就是:
直接计算 \(f_i, g_i\) 是 \(O(n ^ 2)\) 的,不过居然有 \(95\) 分啊 \(Q \omega Q\)
考虑如何较快的计算 \(f_i\) 和 \(g_i\)。
我们可以这么处理:枚举 \(\text{AA}\) 串中 \(\text{A}\) 的长度 \(len\) ,然后在原串中每 \(len\) 个位置标记一个关键点,这样一来出现在原串中的 \(\text{AA}\) 串无论如何都会经过且只经过两个相邻的关键点。
有一说一 这谁想得到啊
那么我们就考虑用两个相邻的关键点来算 \(\text{AA}\) 的个数。
对于两个相邻的关键点 \(i, j\),如果我们求出 \(\text{LCS}(i - 1, j - 1)\) 和 \(\text{LCP}(i, j)\),分别记为 \(lcs, lcp\),那么我们就只需要考察 \(lcs\) 和 \(lcp\) 在区间 \([i, j]\) 内的交的情况就好了(此处我们认为贴在一起也算有交)
结合一个图来看:
如果 \(lcs\) 和 \(lcp\) 有交的话,那么右端点在这个交里的 \(\text{A}\) 都可以在右边复制一遍形成一个 \(\text{AA}\) 串。
具体原理看看图,自己想一想。
那么也就是说,我们每次这么构造出来的 \(\text{AA}\) 串都是连续分布的,我们就考虑差分统计,最后再像之前那样算答案就好了。
参考代码:
#include <cstring>
#include <cstdio>
int min(int a, int b) { return a < b ? a : b; }
typedef long long LL;
const int _ = 1e5 + 5, base = 131, mod1 = 1e9 + 7, mod2 = 1e9 + 9;
char s[_]; int n, f[_], g[_];
int pw1[_], hash1[_], pw2[_], hash2[_];
int Hash1(int l, int r) { return (mod1 + hash1[r] - 1ll * hash1[l - 1] * pw1[r - l + 1] % mod1) % mod1; }
int Hash2(int l, int r) { return (mod2 + hash2[r] - 1ll * hash2[l - 1] * pw2[r - l + 1] % mod2) % mod2; }
int chk(int a, int b, int c, int d) { return Hash1(a, b) == Hash1(c, d) && Hash2(a, b) == Hash2(c, d); }
int LCP(int i, int j) {
int l = 0, r = min(n - i + 1, n - j + 1);
while (l < r) {
int mid = (l + r + 1) >> 1;
if (chk(i, i + mid - 1, j, j + mid - 1)) l = mid; else r = mid - 1;
}
return l;
}
int LCS(int i, int j) {
int l = 0, r = min(i, j);
while (l < r) {
int mid = (l + r + 1) >> 1;
if (chk(i - mid + 1, i, j - mid + 1, j)) l = mid; else r = mid - 1;
}
return l;
}
void solve() {
memset(hash1, 0, sizeof hash1);
memset(hash2, 0, sizeof hash2);
memset(f, 0, sizeof f);
memset(g, 0, sizeof g);
scanf("%s", s + 1), n = strlen(s + 1);
for (int i = 1; i <= n; ++i) {
hash1[i] = (1ll * hash1[i - 1] * base % mod1 + s[i]) % mod1;
hash2[i] = (1ll * hash2[i - 1] * base % mod2 + s[i]) % mod2;
}
for (int len = 1; len << 1 <= n; ++len)
for (int i = len, j = i + len; j <= n; i += len, j += len) {
int lcp = min(len, LCP(i, j)), lcs = min(len - 1, LCS(i - 1, j - 1));
int cap = lcp + lcs - len + 1;
if (lcp + lcs >= len) {
++g[i - lcs], --g[i - lcs + cap];
++f[j + lcp - cap], --f[j + lcp];
}
}
for (int i = 1; i <= n; ++i) f[i] += f[i - 1], g[i] += g[i - 1];
LL ans = 0;
for (int i = 1; i < n; ++i) ans += 1ll * f[i] * g[i + 1];
printf("%lld\n", ans);
}
int main() {
#ifndef ONLINE_JUDGE
freopen("cpp.in", "r", stdin), freopen("cpp.out", "w", stdout);
#endif
pw1[0] = pw2[0] = 1;
for (int i = 1; i < _; ++i) {
pw1[i] = 1ll * pw1[i - 1] * base % mod1;
pw2[i] = 1ll * pw2[i - 1] * base % mod2;
}
int T; scanf("%d", &T);
while (T--) solve();
return 0;
}