bzoj3160: 万径人踪灭
https://www.lydsy.com/JudgeOnline/problem.php?id=3160
不连续的回文串数量=所有的回文序列数量-连续的回文子串
连续的回文子串:
manacher 得到的以i为中心的连续回文串数量=以i为中心的最长回文半径长度
所有的回文序列:
将a看做1,b看做0,自己跟自己做一遍fft
得到的a[i]就是以i/2为中心的由a构成的最长回文序列长度
将a看做0,b看做1,自己跟自己做一遍fft
得到的b[i]就是以i/2为中心的由b构成的最长回文序列长度
因为可以不连续,所以每一对以i为中心的对称位置要么同时选,要么同时不选
所以以i为中心的回文序列数量=2^(f[i]/2 [上取整])-1
#include<cmath> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int N=(1<<18)+2; const double pi=acos(-1); const int mod=1e9+7; char s[N]; int n; struct Complex { double x,y; Complex(double x_=0,double y_=0):x(x_),y(y_){} Complex operator + (Complex P) { return Complex(x+P.x,y+P.y); } Complex operator - (Complex P) { return Complex(x-P.x,y-P.y); } Complex operator * (Complex P) { return Complex(x*P.x-y*P.y,x*P.y+y*P.x); } }; typedef Complex E; E a[N],b[N]; int rev[N]; int f[N]; char t[N]; int p[N]; int Pow(int a,int b) { int res=1; for(;b;b>>=1,a=1LL*a*a%mod) if(b&1) res=1LL*res*a%mod; return res; } void fft(E *a,int len,int tag) { for(int i=0;i<len;++i) if(i<rev[i]) swap(a[i],a[rev[i]]); for(int i=1;i<len;i<<=1) { E wn(cos(pi/i),tag*sin(pi/i)); for(int p=i<<1,j=0;j<len;j+=p) { E w(1,0); for(int k=0;k<i;++k,w=w*wn) { E x=a[j+k],y=a[j+k+i]*w; a[j+k]=x+y; a[j+k+i]=x-y; } } } if(tag==-1) { for(int i=0;i<len;++i) a[i].x=(a[i].x+0.5)/len; } } int solve_all() { for(int i=0;i<n;++i) if(s[i]=='a') a[i].x+=1; else b[i].x=1; int num=n*2-1,len=1,bit=0; while(len<num) len<<=1,bit++; for(int i=0;i<len;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1); fft(a,len,1); for(int i=0;i<len;++i) a[i]=a[i]*a[i]; fft(a,len,-1); fft(b,len,1); for(int i=0;i<len;++i) b[i]=b[i]*b[i]; fft(b,len,-1); for(int i=0;i<len;++i) f[i]=a[i].x+b[i].x; int sum=0; for(int i=0;i<len;++i) { sum+=Pow(2,f[i]+1>>1)-1; sum-=sum>=mod ? mod : 0; } return sum; } void manacher(int m) { int id=0,pos=0,x=0; for(int i=1;i<=m;++i) { if(pos>i) x=min(p[id*2-i],pos-i); else x=1; while(t[i-x]==t[i+x]) x++; if(i+x>pos) pos=i+x,id=i; p[i]=x; } } int solve_continuous() { int m=0; t[m]='!'; for(int i=0;i<n;++i) { t[++m]='#'; t[++m]=s[i]; } t[++m]='#'; t[m+1]='@'; manacher(m); int sum=0; for(int i=1;i<=m;++i) { sum+=p[i]>>1; sum-=sum>=mod ? mod : 0; } return sum; } int main() { scanf("%s",s); n=strlen(s); int t1=solve_all(); int t2=solve_continuous(); printf("%d",(t1-t2+mod)%mod); }