BZOJ4892: [Tjoi2017]dna

这题虽然随便用啥方法求个LCP就完事了,但是显然也可以FFT,并且FFT可以允许任意个字符不同,当然缺点是字符集必须足够小。把第二个串倒过来后,对于每种字符,把出现的位置设为1,其他设为0,就可以用卷积求出所有位置该字符的匹配个数,最后把所有字符的结果加起来即可。

很久以前我就这么做了,然而并没有卡过去,今天突然看到了这题,又卡了一波就卡过去了。首先使用“1.5次FFT”的优化,然后可以让两种字符一起匹配。具体而言,一种字符的位置设为1,另一种设为m+1,这样对于卷积后的一项s,第一种字符的匹配个数是$s\bmod(m+1)$,第二种的匹配个数是$\lfloor s/(m+1)^2\rfloor$。

#include<bits/stdc++.h>
using namespace std;
const int N=1<<16;
typedef long long ll;
typedef double flo;
const flo pi=acos(-1.);
struct vec{
	flo x,y;
	vec operator+(const vec&b)const{return{x+b.x,y+b.y};}
	vec operator-(const vec&b)const{return{x-b.x,y-b.y};}
	vec operator*(const vec&b)const{return{x*b.x-y*b.y,x*b.y+y*b.x};}
	vec operator+(flo b)const{return{x+b,y};}
	vec operator*(flo b)const{return{x*b,y*b};}
};
vec conj(const vec&b){return{b.x,-b.y};}
vec a[N],b[N],c[N],w[N/2];
void fft(vec*a,int n){
	for(int i=0,j=0;i<n;++i){
		if(i<j)
			swap(a[i],a[j]);
		int k=n>>1;
		while((j^=k)<k)
			k>>=1;
	}
	w[0]={1};
	for(int i=1;i<n;i<<=1){
		for(int j=i-2;j>0;j-=2)
			w[j]=w[j>>1];
		vec s={cos(pi/i),sin(pi/i)};
		for(int j=1;j<i;j+=2)
			w[j]=s*w[j-1];
		for(int j=0;j<n;j+=i<<1){
			vec*b=a+j,*c=b+i;
			for(int k=0;k<i;++k){
				vec v=w[k]*c[k];
				c[k]=b[k]-v,b[k]=b[k]+v;
			}
		}
	}
}
int q,m,l,r[N*2];
char u[N*2],v[N*2];
int cal(char a,const char*f){
	return a==f[0]?1:a==f[1]?m+1:0;
}
void sol(const char*f){
	for(int i=0;i<l<<1;++i){
		(i&1?a[i>>1].y:a[i>>1].x)=cal(u[i],f);
		(i&1?b[i>>1].y:b[i>>1].x)=cal(v[i],f);
	}
	fft(a,l);
	fft(b,l);
	for(int i=0;i<l;++i){
		int j=l-i&l-1;
		c[j]=vec({0,.25})*(conj(a[j]*b[j])*4-(conj(a[j])-a[i])*(conj(b[j])-b[i])*((i<l/2?w[i]:w[i-l/2]*-1)+1));
	}
	fft(c,l);
	ll t=(ll)(m+1)*(m+1);
	for(int i=0;i<l<<1;++i){
		ll s=(i&1?c[i>>1].x:c[i>>1].y)/l+.5;
		r[i]+=s%(m+1)+s/t;
	}
}
int main(){
	scanf("%d",&q);
	while(q--){
		scanf("%s%s",u,v);
		int n=strlen(u);
		m=strlen(v);
		if(m>n)
			puts("0");
		else if(m<=3)
			printf("%d\n",n-m+1);
		else{
			reverse(v,v+m);
			l=1<<__lg(n*2-1);
			fill(u+n,u+l,0);
			fill(v+m,v+l,0);
			fill(r,r+l,0);
			l>>=1;
			sol("AT");
			sol("CG");
			int s=0;
			for(int i=m-1;i<n;++i)
				s+=m-r[i]<=3;
			printf("%d\n",s);
		}
	}
}
posted @ 2018-08-08 16:12  f321dd  阅读(280)  评论(0编辑  收藏  举报