BZOJ 3160: 万径人踪灭

考虑所有(位置关于某个中心对称的)回文子序列的个数,减去其中连续的回文子串的个数,即为答案
后者可以用 Manacher 算法求出
\(f_i\) 为以 \(\dfrac{i}{2}\) 为中心的回文子序列个数
若以 \(\dfrac{i}{2}\) 为中心左右有 \(x_i\) 对对称的字符
\(f_i = 2^{x_i} - 1\)
所有 \(x_i\) 可以由FFT求出
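\(x_i = \left\lceil \dfrac{1}{2}\sum\limits_{j+k=i}[s_j = s_k] \right\rceil\)(和式对有序对 \((j,k)\) 计数:\(j \neq k\) 的每一对被数两次,\(j = k\) 的中心字符只数一次,故除以 2 后上取整)
将字符 \(a\) 和 \(b\) 分开单独考虑(分别取其 0/1 指示序列与自身做卷积),记得到的卷积序列分别为 \(A\) 和 \(B\)
\(x_i = \left\lceil \dfrac{A_i + B_i}{2} \right\rceil\)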

#include <bits/stdc++.h>
// Shorthand macros (competitive-programming template). They are plain text substitutions,
// so every argument and expansion is pasted verbatim at the use site.
// Container / pair shorthands. NOTE(review): pii/pli expand to an unqualified `pair`, which needs
// `using namespace std` (not present in this file) or a local `pair` at the use site.
#define pb push_back
#define fi first
#define se second
#define pii pair<int, int>
#define pli pair<ll, int>
// Heap-style child indices of node p, and the midpoint of [l, r].
// NOTE(review): lp/rp are not parenthesized (e.g. `2 * rp` expands to `2 * p << 1 | 1`), and `mid`
// silently binds whatever `l` and `r` are in scope at the expansion site.
#define lp p << 1
#define rp p << 1 | 1
#define mid ((l + r) / 2)
// Lowest set bit of i.
#define lowbit(i) ((i) & (-i))
// Type abbreviations.
#define ll long long
#define ull unsigned long long
#define db double
// rep: i = a, a+1, ..., b-1.  per: the same range walked backwards, i = b-1, ..., a.
#define rep(i,a,b) for(int i=a;i<b;i++)
#define per(i,a,b) for(int i=b-1;i>=a;i--)
// Edg / Edgc expand to a linked-list adjacency structure (head/to/ne, plus weight c[] for Edgc) with
// addd (one directed edge) and add (both directions). ccnt starts at 1, so edge ids start at 2; when edges
// are only added through add, edge i and i ^ 1 are each other's reverse. Constants N and E must be defined
// before the macro is expanded.
#define Edg int ccnt=1,head[N],to[E],ne[E];void addd(int u,int v){to[++ccnt]=v;ne[ccnt]=head[u];head[u]=ccnt;}void add(int u,int v){addd(u,v);addd(v,u);}
#define Edgc int ccnt=1,head[N],to[E],ne[E],c[E];void addd(int u,int v,int w){to[++ccnt]=v;ne[ccnt]=head[u];c[ccnt]=w;head[u]=ccnt;}void add(int u,int v,int w){addd(u,v,w);addd(v,u,w);}
// Iterate over the edges i leaving u, with v = to[i] the edge's endpoint; edge id 0 terminates the list.
#define es(u,i,v) for(int i=head[u],v=to[i];i;i=ne[i],v=to[i])
// Modulus used for all counting in this file.
const int MOD = 1e9 + 7;

// Applies at most one +/-MOD correction in place: any x in [-MOD, 2*MOD) ends up in [0, MOD).
void M(int &x) {
	if (x >= MOD)
		x -= MOD;
	else if (x < 0)
		x += MOD;
}

// a^b mod MOD by binary exponentiation. With the default exponent MOD - 2 this is the
// modular inverse of a (Fermat's little theorem; MOD is prime).
int qp(int a, int b = MOD - 2) {
	long long base = a;
	long long result = 1;
	while (b != 0) {
		if (b & 1)
			result = result * base % MOD;
		base = base * base % MOD;
		b >>= 1;
	}
	return static_cast<int>(result % MOD);
}
// Greatest common divisor by Euclid's algorithm (recursive form); gcd(a, 0) == a.
template<class T>T gcd(T a, T b) {
	return b ? gcd(b, a % b) : a;
}
// If b is smaller than a (a > b), stores b into a. Returns whether a changed.
template<class T>bool chkmin(T &a, T b) {
	if (!(a > b))
		return false;
	a = b;
	return true;
}
// If b is larger than a (a < b), stores b into a. Returns whether a changed.
template<class T>bool chkmax(T &a, T b) {
	if (!(a < b))
		return false;
	a = b;
	return true;
}
// Buffered stdin reader: fread fills a 2 MiB buffer, and [p1, p2) is the unread part of it.
char buf[1 << 21], *p1 = buf, *p2 = buf;
// Returns the next input byte as a value in [0, 255], or EOF once stdin is exhausted.
// The return type is int (not char) so EOF stays distinguishable from byte 0xFF on every platform.
inline int getc() {
	if (p1 == p2) {
		p1 = buf;
		p2 = buf + fread(buf, 1, sizeof buf, stdin);
		if (p1 == p2)
			return EOF;
	}
	return static_cast<unsigned char>(*p1++);
}
// Reads one decimal integer. Non-digits before it are skipped, and a '-' among them makes the result negative.
// Returns 0 when the input ends before any digit; the skip loop now stops at EOF instead of spinning forever.
inline int _() {
	int x = 0, f = 1;
	int ch = getc();
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') f = -1;
		ch = getc();
	}
	while (ch >= '0' && ch <= '9') {
		// Accumulate in 64 bits, then narrow (same arithmetic as before).
		x = static_cast<int>(x * 10LL + (ch - '0'));
		ch = getc();
	}
	return x * f;
}

// Capacity of the per-character buffers below.
const int N = 1 << 18;
// Input string (NUL-terminated) and its length.
char s[N];
int n;
// 0/1 indicator sequences: a[i] is 1 iff s[i] == 'a', b[i] is 1 iff s[i] == 'b'.
int a[N];
int b[N];
// Power table bin[k] = 2^k mod MOD (filled by FFT::solve).
int bin[1 << 18];

namespace FFT {
// Capacity of every static buffer below; the transform length n must never exceed it.
static const int L = 1 << 18;
const db pi = acos(-1.0);

// Minimal complex number (real part r, imaginary part i) supporting +, - and *.
struct Complex {
	db r, i;
	Complex() {}
	Complex(db r, db i): r(r), i(i) {}
	Complex operator + (const Complex &p) const { return Complex(r + p.r, i + p.i); }
	Complex operator - (const Complex &p) const { return Complex(r - p.r, i - p.i); }
	Complex operator * (const Complex &p) const { return Complex(r * p.r - i * p.i, r * p.i + i * p.r); }
} A[L], B[L];

// n: current transform length (a power of two); l = log2(n); r[i]: i with its low l bits reversed.
int n, l, r[L];

// Picks the smallest power of two n with n > m and builds the bit-reversal table r[0..n).
// r[], A[], B[] (and the arrays handed to solve) hold only L entries, so a request that would need
// n > L (i.e. m >= L) aborts instead of writing past their ends.
void init(int m) {
	if (m >= L) {
		std::fprintf(stderr, "FFT::init: transform longer than %d required (m = %d)\n", L, m);
		std::abort();
	}
	n = 1, l = 0;
	while (n <= m) n <<= 1, l++;
	// r[0] is always 0. Starting the loop at 1 also avoids the shift by l - 1 == -1 when n == 1.
	r[0] = 0;
	for (int i = 1; i < n; i++) {
		r[i] = r[i >> 1] >> 1 | ((i & 1) << (l - 1));
	}
}

// In-place iterative radix-2 FFT of a[0..n), using n and r[] from init().
// pd = 1 uses twiddles e^{+i*pi*k/mi}. pd = -1 uses e^{-i*pi*k/mi} and then divides every entry by n,
// so FFT(a, -1) inverts FFT(a, 1).
void FFT(Complex *a, int pd) {
	for (int i = 0; i < n; i++)
		if (i < r[i])
			std::swap(a[i], a[r[i]]);
	for (int mi = 1; mi < n; mi <<= 1) {
		Complex wn(cos(pi / mi), pd * sin(pi / mi));
		for (int l = mi << 1, j = 0; j < n; j += l) {
			Complex w(1.0, 0.0);
			for (int k = 0; k < mi; k++, w = w * wn) {
				Complex u = a[k + j], v = w * a[k + j + mi];
				a[k + j] = u + v;
				a[k + j + mi] = u - v;
			}
		}
	}
	if (pd == -1)
		for (int i = 0; i < n; i++)
			a[i] = Complex(a[i].r / n, a[i].i / n);
}

// Counts palindromic subsequences whose positions are symmetric about some centre, modulo MOD.
// a[i] / b[i] are 0/1 indicators of s[i] == 'a' / 'b'; indices up to FFT::n are read, so entries past
// the string must be 0.
// The convolution c = a*a + b*b gives, for each index k, c[k] = #ordered pairs (i, j) with i + j == k
// and s[i] == s[j]. The centre pair i == j is counted once and every other pair twice, so
// x = (c[k] + 1) >> 1 positions can be mirrored about centre k/2. That gives 2^x - 1 non-empty choices.
// Side effect: fills the global table bin[t] = 2^t mod MOD for t < L.
int solve(int *a, int *b) {
	rep (i, 0, n) A[i] = Complex(1.0 * a[i], 0), B[i] = Complex(1.0 * b[i], 0);
	FFT(A, 1); FFT(B, 1);
	rep (i, 0, n) A[i] = A[i] * A[i] + B[i] * B[i];
	FFT(A, -1);
	bin[0] = 1;
	rep (i, 1, L) bin[i] = bin[i - 1] * 2LL % MOD;
	int ans = 0;
	rep (i, 0, n) {
		// The inverse transform is integer-valued up to rounding error; round to nearest.
		int cnt = int(A[i].r + 0.5);
		M(ans += bin[(cnt + 1) >> 1] - 1);
	}
	return ans;
}
}

char mp[N * 2];
int ma[N * 2];

int manacher() {
	int ans = 0;
	int mx = 0, id = 0;
	int l = 0;
	mp[l] = '@'; mp[++l] = '#';
	rep (i, 0, n) mp[++l] = s[i], mp[++l] = '#';
	rep (i, 1, l + 1) {
		ma[i] = (mx > i) ? (std::min(mx - i, ma[id * 2 - i])) : 1;
		while (mp[i - ma[i]] == mp[i + ma[i]]) ma[i]++;
		if (i + ma[i] > mx) mx = i + ma[i], id = i;
		M(ans += ma[i] >> 1);
	}
	return ans;
}

// Reads the string s and prints (#position-symmetric palindromic subsequences) minus
// (#contiguous palindromic substrings), modulo MOD.
int main() {
#ifdef LOCAL
	freopen("ans.out", "w", stdout);
#endif
	scanf("%s", s);
	n = strlen(s);
	for (int i = 0; i < n; i++) {
		a[i] = (s[i] == 'a');
		b[i] = (s[i] == 'b');
	}
	// Convolution indices reach 2(n - 1); the transform length is the first power of two above 2n + 1.
	FFT::init(2 * n + 1);
	int total = FFT::solve(a, b);
	const int contiguous = manacher();
	total -= contiguous;
	M(total);
	printf("%d\n", total);
	return 0;
}
posted @ 2020-03-03 21:46  Mrzdtz220  阅读(114)  评论(0)  编辑  收藏  举报