bzoj 3160: 万径人踪灭
Description
Solution
由于回文串要不连续 , 我们可以用总方案-连续的方案 , 而连续的方案就是回文串的个数 , 可以用 \(manacher\) 求出 .
对于总方案 , 我们枚举一个回文中心 , 要么是空隙 , 要么是一个位置 .
设以这个点对称的相等字符对有 \(k\) 对 , 对答案的贡献就是 \(2^k-1\) .
这个用 \(FFT\) 卷积一下就可以得到每一个点作为中心的相等字符对了 .
注意常数优化 :
首先枚举空隙不需要倍长序列 , 其对应卷积后奇数位的答案 , 常数可以除 \(2\) .
注意要手写 \(complex\) , 也可以预处理一下单位根 .
#include<bits/stdc++.h>
#define RG register
using namespace std;
const int N=2e5+10,mod=1e9+7;const double pi=acos(-1.0);
inline int qm(int x,int k){
int sum=1;
for(;k;k>>=1,x=1ll*x*x%mod)if(k&1)sum=1ll*sum*x%mod;
return sum;
}
struct dob{
double x,y;//x+y*sqrt(-1)
dob(){}
dob(double _x,double _y){x=_x;y=_y;}
inline dob operator *(dob &p){
return dob(x*p.x-y*p.y,x*p.y+y*p.x);
}
inline dob operator +(dob &p){
return dob(x+p.x,y+p.y);
}
inline dob operator -(dob &p){
return dob(x-p.x,y-p.y);
}
};
char S[N],s[N];int m,L,R[N*2],a[N*2],w[N],n=0,f[N],ans=0,cnt;dob A[N*2];
inline void FFT(int o){
for(RG int i=0;i<m;i++)if(i<R[i])swap(A[i],A[R[i]]);
for(RG int i=1;i<m;i<<=1){
RG dob w0(cos(pi/i),sin(o*pi/i)),x,y;
for(RG int j=0;j<m;j+=i<<1){
RG dob w(1,0);
for(RG int k=0;k<i;k++,w=w*w0){
x=A[j+k];y=A[j+k+i]*w;
A[j+k]=x+y;A[j+k+i]=x-y;
}
}
}
}
inline void calc(){
for(m=1,L=0;m<=(cnt<<1);m<<=1)L++;
for(RG int i=0;i<m;i++)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1)),A[i]=dob(a[i],0);
FFT(1);
for(RG int i=0;i<m;i++)A[i]=A[i]*A[i];
FFT(-1);
for(RG int i=1;i<m;i++)
if(i&1)w[(i>>1)+cnt]+=(int)(A[i].x/m+0.5);
else w[i>>1]+=(int)(A[i].x/m+0.5);
}
inline void manacher(){
int id=1;
for(int i=1;i<=n;i++){
if(i<id+f[id])f[i]=min(id+f[id]-i,f[id*2-i]);
else f[i]=1;
while(s[i+f[i]]==s[i-f[i]])++f[i];
if(f[i]>f[id])id=i;
}
}
int main(){
freopen("pp.in","r",stdin);
freopen("pp.out","w",stdout);
scanf("%s",S+1);
cnt=strlen(S+1);
for(int i=1;i<=cnt;i++)s[++n]='#',s[++n]=S[i];
s[0]='x';s[++n]='#';s[n+1]='y';
for(int i=1;i<=cnt;i++)a[i]=(S[i]=='a');
calc();
for(int i=1;i<=cnt;i++)a[i]=(S[i]=='b');
calc();
for(int i=1;i<=cnt;i++)ans=(ans+qm(2,(w[i]>>1)+1)-1)%mod;
for(int i=cnt+1;i<=n;i++)ans=(ans+qm(2,w[i]>>1)-1)%mod;
manacher();
for(int i=1;i<=n;i++)ans=(ans-(f[i]>>1)+mod)%mod;
if(ans<0)ans+=mod;
cout<<ans;
return 0;
}