【BZOJ3160】万径人踪灭 - 题解

题目链接

【BZOJ3160】万径人踪灭

做法

不包含连续下标的回文子序列 = 所有回文子序列 - 连续下标的回文子序列。
对于连续下标的回文子序列,可以用 $ Manacher $ 算法快速计算。
对于所有回文子序列,考虑枚举对称中心 $ r $ , 若有 $ k $ 组 $ (x, y) $ 满足 $ x \not= y $ 且 $ x + y = 2r $ 且 $ s_x = s_y $ ,那么方案数为 $ 2^{k + 1} - 1 $ ($ k $ 组 $ (x, y) $ 加上 $ r $)。若对称中心在 $ r $ 和 $ r + 1 $ 之间,此时 \((x, y)\) 应当满足 $ x + y = 2r + 1 $ , 方案数为 $ 2^k - 1 $ 。
发现寻找 $ s_x, s_y $ 可以用卷积来写,就直接 $ NTT $ 了。
时间复杂度 $ O(n \log n) $ 。

#include <bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
const int N = 200010;
const int mod = 1e9 + 7;
int n, m, RL[N + N], cnt[N + N], ba[N];
char s[N], t[N + N];
int ans = 0;
vector<int> a, b, res;

inline int Min(const int &x, const int &y) { return x < y ? x : y; }
inline int Max(const int &x, const int &y) { return x > y ? x : y; }
namespace Poly {
	const int mod = 998244353;
	inline int add(const int &x, const int &y) {
		return x + y < mod ? x + y : x + y - mod;
	}
	inline int sub(const int &x, const int &y) {
		return x - y < 0 ? x - y + mod : x - y;
	}
	inline int mul(const int &x, const int &y) {
		return (int)((ll) x * y % mod);
	}
	int ksm(int x, int y = mod - 2) {
		int ss = 1;
		for(; y; y >>= 1, x = mul(x, x)) if(y & 1) ss = mul(ss, x);
		return ss;
	}
	int Get(int x) { int ss = 1; for(; ss <= x; ss <<= 1); return ss; }
	void ntt(vector<int> &a, int lmt, int opt) {
		a.resize(lmt);
		for(int i = 0, j = 0; i < lmt; i++) {
			if(i < j) swap(a[i], a[j]);
			for(int k = lmt >> 1; (j ^= k) < k; k >>= 1);
		}
		vector<int> w(lmt >> 1);
		for(int mid = 1; mid < lmt; mid <<= 1) {
			w[0] = 1;
			int w0 = ksm(opt == 1 ? 3 : (mod + 1) / 3, (mod - 1) / mid / 2);
			for(int i = 1; i < mid; i++) w[i] = mul(w[i - 1], w0);
			for(int j = 0; j < lmt; j += mid << 1)
				for(int k = 0; k < mid; k++) {
					int x = a[j + k], y = mul(a[j + mid + k], w[k]);
					a[j + k] = add(x, y), a[j + mid + k] = sub(x, y);
				}
		}
		if(opt == -1)
			for(int i = 0, inv = ksm(lmt); i < lmt; i++)
				a[i] = mul(a[i], inv);
	}
	vector<int> Mul(const vector<int> &A, const vector<int> &B) {
		vector<int> a = A, b = B; int lmt = Get(A.size() + B.size() - 2);
		ntt(a, lmt, 1), ntt(b, lmt, 1);
		for(int i = 0; i < lmt; i++) a[i] = mul(a[i], b[i]);
		ntt(a, lmt, -1); return a.resize(A.size() + B.size() - 1), a;
	}
}

void Manacher() {
	int mx = 0, pos = 0; RL[0] = 1;
	for(int i = 0; i < m; i++) {
		RL[i] = mx > i ? Min(RL[pos * 2 - i], mx - i) : 1;
		for(; t[i - RL[i]] == t[i + RL[i]]; ++RL[i]);
		if(i + RL[i] > mx) mx = i + RL[i], pos = i;
	}
	for(int i = 0; i < m; i++) ans = (ans - RL[i] / 2 + mod) % mod;
}
int main() {
	scanf("%s", s), n = strlen(s);
	t[m++] = '*'; for(int i = 0; i < n; i++) t[m++] = '#', t[m++] = s[i];
	t[m++] = '#', Manacher();
	ba[0] = 1;
	for(int i = 1; i <= n; i++) ba[i] = (ba[i - 1] + ba[i - 1]) % mod;
	//solve a
	a.clear(), b.clear(), res.clear();
	for(int i = 0; i < n; i++) a.pb(s[i] == 'a' ? 1 : 0);
	b = a, res = Poly::Mul(a, b), res.resize(n + n);
	for(int i = 0; i < n; i++) if(s[i] == 'a') --res[i + i];
	for(int i = 0; i <= n + n - 2; i++) cnt[i] += res[i] / 2;
	//solve b
	a.clear(), b.clear(), res.clear();
	for(int i = 0; i < n; i++) a.pb(s[i] == 'b' ? 1 : 0);
	b = a, res = Poly::Mul(a, b), res.resize(n + n);
	for(int i = 0; i < n; i++) if(s[i] == 'b') --res[i + i];
	for(int i = 0; i <= n + n - 2; i++) cnt[i] += res[i] / 2;
//	for(int i = 0; i <= n + n - 2; i++)
//		printf(">>> %d -> %d : %d\n", i / 2, (i + 1) / 2, cnt[i]);
	for(int i = 0; i <= n + n - 2; i++) {
		if(i & 1) ans = ((ans + ba[cnt[i]]) % mod + mod - 1) % mod;
		else ans = ((ans + ba[cnt[i] + 1]) % mod + mod - 1) % mod;
	}
	printf("%d\n", ans);
	return 0;
}
posted @ 2019-04-13 13:36  daniel14311531  阅读(192)  评论(0编辑  收藏  举报