BZOJ3160 万径人踪灭 字符串 多项式 Manachar FFT

原文链接http://www.cnblogs.com/zhouzhendong/p/8810140.html

题目传送门 - BZOJ3160

题意

  给你一个只含$a,b$的字符串,让你选择一个子序列,使得:

  $1.$位置和字符都关于某一条对称轴对称。

  $2.$不能是连续的一段。

  问原来的字符串中能找出多少个这样的子序列。答案对$10^9+7$取模。

  串长$\leq 10^5$。

题解

  下面的讨论都在满足条件$1$的情况下进行。

  首先,我们先不考虑条件$2$。然后再减掉不满足条件$2$的就可以了。

  显然不满足条件$2$的就是一个连续的回文串。我们只需要统计原串中的连续回文子串个数即可,这个可以用$Manachar$算法轻松搞定。

  考虑只满足条件$1$的。

  注意考虑第一个条件中位置也关于某一条对称轴对称。

  显然,考虑一个位置$i$,如果$s_{i-j}\neq s_{i+j}((i+j)\in N,(i-j)\in N)$,那么显然不能出现这一对字符。否则,这对字符就可以出现也可以不出现。

  如果位置$i$的两侧有$k$对可以出现也可以不出现的字符,那么对答案的贡献就是$2^k-1$。

  减$1$的原因是不能不取任何字符对,如果那样的话是空串,不满足条件。

  我们观察$s_{i-j}$和$s_{i+j}$的下标,发现$i-j+i+j=2i$考虑$j$为变量的话是个定值,那么恰好满足的多项式乘法卷积的形式。

  考虑构造多项式。

  下面引用框里面是我一开始自己想的算法,常数要大不少,可能过不去,但是也是对的。

设$f_i,g_i$分别表示的$i$位的字符值$('a'=0,'b'=1)$和的$i$位是否有效$(下标超过原来的串长则g_i=0,否则g_i=1)$。

构造卷积

$$h_i=\sum_{j=0}^{i}g_ig_{i-j}(f_i-f_{i-j})^2$$

然后展开就可以$FFT$了。

得到的$h_i$是当以$\frac i2$为对称轴的时候$\frac i2$左右不能匹配的对数。

  但是事实上有更好的做法。

  考虑$'a'$和$'b'$,分开考虑,然后合法的数对个数加起来就是总的合法数对个数了。

  这里只说$'a'$。

  定义$g_i$,表示如果$s_i='a'$则$g_i=1$,否则$g_i=0$。

  卷积$f=g^2$,即$f_i=\sum_{j=0}^i g_jg_{i-j}$。$FFT$优化即可。

  得到的$f_i$和之前的意义差不多,只是表示的是能匹配的对数了。

  最后计算答案不用说了吧QAQ。

  然而由于博主只写过$2$~$3$次$Manachar$,不够熟练,导致$Manachar$写错了,贡献了一次$TLE$,一次$WA$。

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=1<<18;
const LL mod=1e9+7;
double PI=acos(-1.0);
int m,n,L,R[N],r[N],res[N];
char s[N],str[N];
LL pow2[N],ans=0;
struct C{
	double r,i;
	C(){}
	C(double a,double b){r=a,i=b;}
	C operator + (C x){return C(r+x.r,i+x.i);}
	C operator - (C x){return C(r-x.r,i-x.i);}
	C operator * (C x){return C(r*x.r-i*x.i,r*x.i+i*x.r);}
}w[N],x[N],y[N],z[N];
void Manachar(char str[]){
	for (int i=0;i<m;i++)
		str[i*2+1]=s[i];
	for (int i=0;i<=m;i++)
		str[i*2]='*';
	int R=0,p=0;
	for (int i=1;i<m*2;i++){
		r[i]=max(1,min(r[p*2-i],R-i));
		while (i-r[i]>=0&&i+r[i]<=m*2&&str[i-r[i]]==str[i+r[i]])
			r[i]++;
		if (i+r[i]>R)
			R=i+r[i],p=i;
	}
}
void FFT(C a[]){
	for (int i=0;i<n;i++)
		if (i<R[i])
			swap(a[i],a[R[i]]);
	for (int t=n>>1,d=1;d<n;d<<=1,t>>=1)
		for (int i=0;i<n;i+=(d<<1))
			for (int j=0;j<d;j++){
				C tmp=w[t*j]*a[i+j+d];
				a[i+j+d]=a[i+j]-tmp;
				a[i+j]=a[i+j]+tmp;
			}
}
int main(){
	scanf("%s",s);
	m=strlen(s);
	Manachar(str);
	for (n=1,L=0;n<m*2;n<<=1,L++);
	for (int i=0;i<n;i++){
		R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
		w[i]=C(cos(2*i*PI/n),sin(2*i*PI/n));
		pow2[i]=i==0?1:(pow2[i-1]*2%mod);
		x[i]=C(0,0),y[i]=C(0,0);
	}
	for (int i=0;i<m;i++)
		x[i]=C(s[i]=='a',0),y[i]=C(s[i]=='b',0);
	FFT(x),FFT(y);
	for (int i=0;i<n;i++)
		x[i]=x[i]*x[i],y[i]=y[i]*y[i],w[i].i*=-1.0;
	FFT(x),FFT(y);
	for (int i=0;i<n;i++)
		res[i]=(int)((x[i].r+y[i].r)/n+0.5);
	for (int i=0;i<=m*2-2;i++)
		ans=(ans+mod+pow2[(res[i]+1)/2]-r[i+1]/2-1)%mod;
	printf("%lld",ans);
	return 0;
}

  

posted @ 2018-04-12 20:10  zzd233  阅读(453)  评论(0编辑  收藏  举报