【BZOJ3160】万径人踪灭 - 题解
题目链接
做法
不包含连续下标的回文子序列 = 所有回文子序列 - 连续下标的回文子序列。
对于连续下标的回文子序列,可以用 $ Manacher $ 算法快速计算。
对于所有回文子序列,考虑枚举对称中心 $ r $ , 若有 $ k $ 组 $ (x, y) $ 满足 $ x \not= y $ 且 $ x + y = 2r $ 且 $ s_x = s_y $ ,那么方案数为 $ 2^{k + 1} - 1 $ ($ k $ 组 $ (x, y) $ 加上 $ r $)。若对称中心在 $ r $ 和 $ r + 1 $ 之间,此时 \((x, y)\) 应当满足 $ x + y = 2r + 1 $ , 方案数为 $ 2^k - 1 $ 。
发现寻找 $ s_x, s_y $ 可以用卷积来写,就直接 $ NTT $ 了。
时间复杂度 $ O(n \log n) $ 。
#include <bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
const int N = 200010;
const int mod = 1e9 + 7;
int n, m, RL[N + N], cnt[N + N], ba[N];
char s[N], t[N + N];
int ans = 0;
vector<int> a, b, res;
inline int Min(const int &x, const int &y) { return x < y ? x : y; }
inline int Max(const int &x, const int &y) { return x > y ? x : y; }
namespace Poly {
const int mod = 998244353;
inline int add(const int &x, const int &y) {
return x + y < mod ? x + y : x + y - mod;
}
inline int sub(const int &x, const int &y) {
return x - y < 0 ? x - y + mod : x - y;
}
inline int mul(const int &x, const int &y) {
return (int)((ll) x * y % mod);
}
int ksm(int x, int y = mod - 2) {
int ss = 1;
for(; y; y >>= 1, x = mul(x, x)) if(y & 1) ss = mul(ss, x);
return ss;
}
int Get(int x) { int ss = 1; for(; ss <= x; ss <<= 1); return ss; }
void ntt(vector<int> &a, int lmt, int opt) {
a.resize(lmt);
for(int i = 0, j = 0; i < lmt; i++) {
if(i < j) swap(a[i], a[j]);
for(int k = lmt >> 1; (j ^= k) < k; k >>= 1);
}
vector<int> w(lmt >> 1);
for(int mid = 1; mid < lmt; mid <<= 1) {
w[0] = 1;
int w0 = ksm(opt == 1 ? 3 : (mod + 1) / 3, (mod - 1) / mid / 2);
for(int i = 1; i < mid; i++) w[i] = mul(w[i - 1], w0);
for(int j = 0; j < lmt; j += mid << 1)
for(int k = 0; k < mid; k++) {
int x = a[j + k], y = mul(a[j + mid + k], w[k]);
a[j + k] = add(x, y), a[j + mid + k] = sub(x, y);
}
}
if(opt == -1)
for(int i = 0, inv = ksm(lmt); i < lmt; i++)
a[i] = mul(a[i], inv);
}
vector<int> Mul(const vector<int> &A, const vector<int> &B) {
vector<int> a = A, b = B; int lmt = Get(A.size() + B.size() - 2);
ntt(a, lmt, 1), ntt(b, lmt, 1);
for(int i = 0; i < lmt; i++) a[i] = mul(a[i], b[i]);
ntt(a, lmt, -1); return a.resize(A.size() + B.size() - 1), a;
}
}
void Manacher() {
int mx = 0, pos = 0; RL[0] = 1;
for(int i = 0; i < m; i++) {
RL[i] = mx > i ? Min(RL[pos * 2 - i], mx - i) : 1;
for(; t[i - RL[i]] == t[i + RL[i]]; ++RL[i]);
if(i + RL[i] > mx) mx = i + RL[i], pos = i;
}
for(int i = 0; i < m; i++) ans = (ans - RL[i] / 2 + mod) % mod;
}
int main() {
scanf("%s", s), n = strlen(s);
t[m++] = '*'; for(int i = 0; i < n; i++) t[m++] = '#', t[m++] = s[i];
t[m++] = '#', Manacher();
ba[0] = 1;
for(int i = 1; i <= n; i++) ba[i] = (ba[i - 1] + ba[i - 1]) % mod;
//solve a
a.clear(), b.clear(), res.clear();
for(int i = 0; i < n; i++) a.pb(s[i] == 'a' ? 1 : 0);
b = a, res = Poly::Mul(a, b), res.resize(n + n);
for(int i = 0; i < n; i++) if(s[i] == 'a') --res[i + i];
for(int i = 0; i <= n + n - 2; i++) cnt[i] += res[i] / 2;
//solve b
a.clear(), b.clear(), res.clear();
for(int i = 0; i < n; i++) a.pb(s[i] == 'b' ? 1 : 0);
b = a, res = Poly::Mul(a, b), res.resize(n + n);
for(int i = 0; i < n; i++) if(s[i] == 'b') --res[i + i];
for(int i = 0; i <= n + n - 2; i++) cnt[i] += res[i] / 2;
// for(int i = 0; i <= n + n - 2; i++)
// printf(">>> %d -> %d : %d\n", i / 2, (i + 1) / 2, cnt[i]);
for(int i = 0; i <= n + n - 2; i++) {
if(i & 1) ans = ((ans + ba[cnt[i]]) % mod + mod - 1) % mod;
else ans = ((ans + ba[cnt[i] + 1]) % mod + mod - 1) % mod;
}
printf("%d\n", ans);
return 0;
}