BZOJ 3160 万径人踪灭
题目链接:万径人踪灭
因为manachar写挂导致这道题调了好久……整个人都不好了……
我们可以发现我们要求回文子序列的数目,并且要求不连续。那么我们显然可以用所有的数目减去连续的数目。
然后,我们可以非常轻易地发现(我怎么看不出来)如果我们只考虑一个字母的话,把这个串看成一个多项式,平方一下,第\(x\)项的系数就是以第\(x\)个位置(两个字母之间的间隙也算一个位置)为对称轴的对称的字母个数。题中给出的串只有\(a\)、\(b\)两种字母,那么做两遍就得到了关于每个位置\(x\)对称的字母数\(c_x\)。
由于我们选的是子序列,那么肯定每个位置都可选可不选(全部不选除外)。因此,这一部分的答案就是\(\sum2^{c_x}-1\)。
然后我们要减去不合法的情况。就是对每个位置求出最长回文串的长度。一遍manachar搞定。
我的程序跑的好慢啊……是不是因为我写的是\(NTT\)而不是\(FFT\)啊……
下面贴代码:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #include<complex> #define File(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout) #define maxn 300010 #define mod 998244353 #define mo 1000000007 using namespace std; typedef long long llg; const int g=3; int n,m,L,R[maxn],N,nn,ans,p[maxn]; int a[maxn],b[maxn],qm[maxn],c[maxn]; char s[maxn],h[maxn]; int getint(){ int w=0;bool q=0; char c=getchar(); while((c>'9'||c<'0')&&c!='-') c=getchar(); if(c=='-') c=getchar(),q=1; while(c>='0'&&c<='9') w=w*10+c-'0',c=getchar(); return q?-w:w; } int qpow(int x,int y){ int s=1; while(y){ if(y&1) s=1ll*s*x%mod; x=1ll*x*x%mod; y>>=1; } return s; } void ntt(int *a){ for(int i=0;i<n;i++) if(i<R[i]) swap(a[i],a[R[i]]); for(int i=1;i<n;i<<=1){ int gn=qpow(g,(mod-1)/(i<<1)),x,y; for(int j=0;j<n;j+=(i<<1)){ int g=1; for(int k=0;k<i;k++,g=1ll*g*gn%mod){ x=a[j+k]; y=1ll*g*a[j+i+k]%mod; a[j+k]=x+y; if(a[j+k]>=mod) a[j+k]-=mod; a[j+i+k]=x-y; if(x<y) a[j+i+k]+=mod; } } } } void work(){ m=(nn-1)<<1; L=0; for(n=1;n<=m;n<<=1) L++; for(int i=0;i<n;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1)); ntt(a); ntt(b); N=qpow(n,mod-2); for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mod; ntt(a); reverse(a+1,a+n); for(int i=0;i<n;i++)c[i]+=1ll*a[i]*N%mod,a[i]=b[i]=0;; } int main(){ File("a"); scanf("%s",s); nn=n=strlen(s); qm[0]=1; for(int i=1;i<(n<<1);i++) (qm[i]=qm[i-1]<<1)%=mo; for(int i=0;i<nn;i++) a[i]=b[i]=(s[i]=='a'); work(); for(int i=0;i<nn;i++) a[i]=b[i]=(s[i]=='b'); work(); for(int i=0;i<n;i++) ans+=(qm[(c[i]+1)>>1]-1),ans%=mo; n=nn; n<<=1; h[n]=101; h[0]=100; for(int i=0;i<nn;i++) h[i<<1|1]=s[i]; int mx=0,id=0; for(int i=1;i<n;i++){ if(mx>i) p[i]=min(p[id*2-i],mx-i+1); else p[i]=1; for(;h[i-p[i]]==h[i+p[i]];p[i]++) if(i+p[i]>mx) mx=i+p[i],id=i; } for(int i=1;i<n;i++){ if(!h[i]) p[i]--; ans-=(p[i]+1)>>1,ans+=mo,ans%=mo; } printf("%d",ans); return 0; }