BZOJ 3160 万径人踪灭

题目链接:万径人踪灭

  因为manachar写挂导致这道题调了好久……整个人都不好了……

  我们可以发现我们要求回文子序列的数目,并且要求不连续。那么我们显然可以用所有的数目减去连续的数目。

  然后,我们可以非常轻易地发现(我怎么看不出来)如果我们只考虑一个字母的话,把这个串看成一个多项式,平方一下,第\(x\)项的系数就是以第\(x\)个位置(两个字母之间的间隙也算一个位置)为对称轴的对称的字母个数。题中给出的串只有\(a\)、\(b\)两种字母,那么做两遍就得到了关于每个位置\(x\)对称的字母数\(c_x\)。

  由于我们选的是子序列,那么肯定每个位置都可选可不选(全部不选除外)。因此,这一部分的答案就是\(\sum2^{c_x}-1\)。

  然后我们要减去不合法的情况。就是对每个位置求出最长回文串的长度。一遍manachar搞定。

  我的程序跑的好慢啊……是不是因为我写的是\(NTT\)而不是\(FFT\)啊……

  下面贴代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<complex>
#define File(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)
#define maxn 300010
#define mod 998244353
#define mo 1000000007
 
using namespace std;
typedef long long llg;
 
const int g=3;
int n,m,L,R[maxn],N,nn,ans,p[maxn];
int a[maxn],b[maxn],qm[maxn],c[maxn];
char s[maxn],h[maxn];
 
int getint(){
	int w=0;bool q=0;
	char c=getchar();
	while((c>'9'||c<'0')&&c!='-') c=getchar();
	if(c=='-') c=getchar(),q=1;
	while(c>='0'&&c<='9') w=w*10+c-'0',c=getchar();
	return q?-w:w;
}
 
int qpow(int x,int y){
	int s=1;
	while(y){
		if(y&1) s=1ll*s*x%mod;
		x=1ll*x*x%mod; y>>=1;
	}
	return s;
}
 
void ntt(int *a){
	for(int i=0;i<n;i++) if(i<R[i]) swap(a[i],a[R[i]]);
	for(int i=1;i<n;i<<=1){
		int gn=qpow(g,(mod-1)/(i<<1)),x,y;
		for(int j=0;j<n;j+=(i<<1)){
			int g=1;
			for(int k=0;k<i;k++,g=1ll*g*gn%mod){
				x=a[j+k]; y=1ll*g*a[j+i+k]%mod;
				a[j+k]=x+y; if(a[j+k]>=mod) a[j+k]-=mod;
				a[j+i+k]=x-y; if(x<y) a[j+i+k]+=mod;
			}
		}
	}
}

void work(){
	m=(nn-1)<<1; L=0; for(n=1;n<=m;n<<=1) L++;
	for(int i=0;i<n;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
	ntt(a); ntt(b); N=qpow(n,mod-2);
	for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mod;
	ntt(a); reverse(a+1,a+n);
	for(int i=0;i<n;i++)c[i]+=1ll*a[i]*N%mod,a[i]=b[i]=0;;
}

int main(){
	File("a");
	scanf("%s",s); nn=n=strlen(s); qm[0]=1;
	for(int i=1;i<(n<<1);i++) (qm[i]=qm[i-1]<<1)%=mo;
	for(int i=0;i<nn;i++) a[i]=b[i]=(s[i]=='a'); work();
	for(int i=0;i<nn;i++) a[i]=b[i]=(s[i]=='b'); work();
	for(int i=0;i<n;i++) ans+=(qm[(c[i]+1)>>1]-1),ans%=mo;
	n=nn; n<<=1; h[n]=101; h[0]=100;
	for(int i=0;i<nn;i++) h[i<<1|1]=s[i];
	int mx=0,id=0;
	for(int i=1;i<n;i++){
		if(mx>i) p[i]=min(p[id*2-i],mx-i+1);
		else p[i]=1;
		for(;h[i-p[i]]==h[i+p[i]];p[i]++)
			if(i+p[i]>mx) mx=i+p[i],id=i;
	}
	for(int i=1;i<n;i++){
		if(!h[i]) p[i]--;
		ans-=(p[i]+1)>>1,ans+=mo,ans%=mo;
	}
	printf("%d",ans);
	return 0;
}
posted @ 2017-02-07 19:38  lcf2000  阅读(146)  评论(0编辑  收藏  举报