【bzoj3160】万径人踪灭 【FFT 卷积】
题目链接
题意:给你一个只含a和b的字符串,求不连续的回文子序列的个数。
题解: 在每两个字符之间补上一个特殊字符。我们让f[i]表示i左右两边(包括i)对称(位置及字符相同)的字符的对数,则显然为回文子序列的个数。但是题目要求的是不连续的,所以最后还要进行一次manacher,减掉连续的个数。
如何求f[i]?
显然
这是一个卷积的形式。
我们弄两个多项式,第一个多项式A所有a的位置都是1,第二个多项式B左右b的位置都是1。我们用FFT把A和B都平方一下,则显然。最后的加一除以二是因为多项式平方的时候同一个系数的结果被算了两次,要去掉。
最后统计即可。
代码:
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
const int N=530005,mod=1000000007;
const double pi=3.141592653589793;
int len,n,m,rev[N],f[N],bin[N],p[N];
long long ans;
char str[N],s[N];
struct complex{
double x,y;
complex(){x=y=0;}
complex(double x,double y):x(x),y(y){}
friend complex operator + (const complex &a,const complex &b){
return complex(a.x+b.x,a.y+b.y);
}
friend complex operator - (const complex &a,const complex &b){
return complex(a.x-b.x,a.y-b.y);
}
friend complex operator * (const complex &a,const complex &b){
return complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}
}a[N],b[N];
void fft(complex *a,int dft){
for(int i=0;i<=n;i++){
if(i<rev[i]){
swap(a[i],a[rev[i]]);
}
}
for(int i=1;i<n;i<<=1){
complex wn=complex(cos(pi/i),dft*sin(pi/i));
for(int j=0;j<n;j+=i<<1){
complex w=complex(1,0);
for(int k=j;k<j+i;k++,w=w*wn){
complex x=a[k];
complex y=w*a[k+i];
a[k]=x+y;
a[k+i]=x-y;
}
}
}
if(dft==-1){
for(int i=0;i<=n;i++){
a[i].x/=n;
}
}
}
void manacher(){
p[0]=1;
int maxr=0,pos=0;
for(int i=1;i<=m;i++){
p[i]=i<maxr?min(p[2*pos-i],maxr-i):1;
while(i-p[i]>=0&&i+p[i]<=m&&s[i-p[i]]==s[i+p[i]]){
p[i]++;
}
if(i+p[i]>maxr){
pos=i;
maxr=i+p[i];
}
}
}
int main(){
bin[0]=1;
for(int i=1;i<N;i++){
bin[i]=2*bin[i-1]%mod;
}
scanf("%s",str+1);
len=n=strlen(str+1);
s[0]='#';
for(int i=1;i<=n;i++){
s[++m]=str[i];
s[++m]='#';
}
for(int i=0;i<=m;i++){
if(s[i]=='a'){
a[i].x=1;
}else if(s[i]=='b'){
b[i].x=1;
}
}
m*=2;
for(n=1;n<=m;n<<=1);
for(int i=0;i<=n;i++){
rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
}
fft(a,1);
for(int i=0;i<=n;i++){
a[i]=a[i]*a[i];
}
fft(a,-1);
fft(b,1);
for(int i=0;i<=n;i++){
b[i]=b[i]*b[i];
}
fft(b,-1);
m/=2;
manacher();
for(int i=1;i<=m;i++){
f[i]=int(a[i<<1].x+0.5)+int(b[i<<1].x+0.5);
f[i]=(f[i]+1)/2;
ans=ans+bin[f[i]]-1-p[i]/2;
ans=(ans+mod)%mod;
}
printf("%lld\n",ans);
return 0;
}