BZOJ 4503 两个串 ——FFT

【题目分析】

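    把 S 和 T 中的字母都映射为 1..26，T 中的通配符 '?' 映射为 0。定义两个字符之间的距离为

    $(a_i-b_i)^2 a_i b_i$

    （代码中实际使用的是等价的 $(a_i-b_i)^2 b_i$：因为 $a_i$ 恒不为 0，两者为 0 的条件完全相同，即 $a_i=b_i$ 或 $b_i$ 是通配符。）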

    T 能在 S 的位置 i 处匹配，当且仅当 $\sum_{j=0}^{|T|-1} (a_{i+j}-b_j)^2 a_{i+j} b_j = 0$。由于每一项都非负，和为 0 等价于每一位都匹配。

    但对每个位置直接计算这个和，复杂度仍是 $O(|S|\cdot|T|)$，和暴力没有什么区别。

    把平方展开，和式拆成 $\sum a^3 b - 2\sum a^2 b^2 + \sum a b^3$ 三项。把 b（即 T）反转之后，每一项都变成一个卷积，于是可以用 FFT 在 $O(n\log n)$ 内一次求出所有位置的值。

    听说KMP+暴力可以卡到100ms以内(雾)

【代码】

#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

// Capacity of every global buffer: the input strings, the integer encodings,
// the bit-reversal table and the FFT work arrays.
#define maxn 400005
// Inclusive ascending / descending loop helpers. The bounds are parenthesized
// so that expression arguments such as F(i,0,n&3) keep their intended
// precedence (unparenthesized, `i<=n&3` would parse as `(i<=n)&3`).
#define F(i,j,k) for (int i=(j);i<=(k);++i)
#define D(i,j,k) for (int i=(j);i>=(k);--i)

const double pi=acos(-1.0);
const double eps=1e-6;  // NOTE(review): not referenced anywhere in this file.

// Minimal complex number providing only the arithmetic the FFT needs.
// The same declaration also defines the global work arrays a, b and c
// that main() fills with polynomial coefficients / transforms.
struct Complex{
    double x,y;   // real part, imaginary part

    Complex operator + (Complex rhs) const
    {
        Complex sum;
        sum.x=x+rhs.x;
        sum.y=y+rhs.y;
        return sum;
    }

    Complex operator - (Complex rhs) const
    {
        Complex diff;
        diff.x=x-rhs.x;
        diff.y=y-rhs.y;
        return diff;
    }

    Complex operator * (Complex rhs) const
    {
        Complex prod;
        prod.x=x*rhs.x-y*rhs.y;
        prod.y=x*rhs.y+y*rhs.x;
        return prod;
    }
}a[maxn],b[maxn],c[maxn];

char s[maxn];    // text S as read from stdin
char t[maxn];    // pattern T as read from stdin ('?' is a wildcard)
int A[maxn];     // S encoded as 1..26
int B[maxn];     // T reversed and encoded as 1..26, with '?' -> 0
int ls;          // strlen(s)
int lt;          // strlen(t)
int rev[maxn];   // bit-reversal permutation used by FFT()
int n;           // FFT length (a power of two once main() has sized it)
int m=1;         // doubled in main() until it exceeds the required length
int len;         // number of doublings, i.e. log2 of the FFT length
int ans[maxn];   // NOTE(review): not referenced anywhere in this file

// In-place iterative radix-2 FFT of x[0..n-1].
//   x : array transformed in place.
//   n : transform length; must be a power of two, and the global rev[] must
//       already hold the bit-reversal permutation for this n.
//   f : +1 for the forward transform, -1 for the inverse transform
//       (the inverse is left unscaled; main() divides by n afterwards).
// Each twiddle factor is evaluated directly with cos/sin. Repeatedly
// multiplying w=w*wn would accumulate rounding error that grows with the
// stage length. That error eats into the margin the caller relies on when
// rounding the convolution results back to integers.
void FFT(Complex * x, int n, int f)
{
	F(i,0,n-1) if (rev[i]>i) swap(x[rev[i]],x[i]);
	for (int m=2;m<=n;m<<=1)
	{
		int mid=m>>1;
		F(j,0,mid-1)
		{
			// w = exp(f * 2*pi*I * j / m), computed exactly once per (stage, j).
			double ang=2.0*pi*j/m*f;
			Complex w; w.x=cos(ang); w.y=sin(ang);
			// Apply the same butterfly to every length-m block of this stage.
			for (int i=0;i<n;i+=m)
			{
				Complex u=x[i+j],v=x[i+j+mid]*w;
				x[i+j]=u+v; x[i+j+mid]=u-v;
			}
		}
	}
}

int main()
{
//	freopen("in.txt","r",stdin);
	scanf("%s",s); scanf("%s",t);
	ls=strlen(s); lt=strlen(t);
	n=ls+lt+1;
	while (m<=n) m<<=1,len++; n=m;
	F(i,0,n-1)
	{
		int t=i,ret=0;
		F(j,1,len) ret<<=1,ret|=t&1,t>>=1;
		rev[i]=ret;
	}
	F(i,0,ls-1) A[i]=s[i]-'a'+1;
	F(i,0,lt-1) {B[lt-i-1]=t[i]-'a'+1;if (t[i]=='?') B[lt-i-1]=0;}
	
	memset(a,0,sizeof a);
	memset(b,0,sizeof b);
	F(i,0,n-1) a[i].x=1;
	F(i,0,n-1) b[i].x=B[i]*B[i]*B[i];
	FFT(a,n,1); FFT(b,n,1);
	F(i,0,n-1)
		c[i]=a[i]*b[i];
	
	memset(a,0,sizeof a);
	memset(b,0,sizeof b);
	F(i,0,n-1) a[i].x=2*A[i];
	F(i,0,n-1) b[i].x=B[i]*B[i];
	FFT(a,n,1); FFT(b,n,1);
	F(i,0,n-1)
		c[i]=c[i]-a[i]*b[i];
	
	memset(a,0,sizeof a);
	memset(b,0,sizeof b);
	F(i,0,n-1) a[i].x=A[i]*A[i];
	F(i,0,n-1) b[i].x=B[i];
	FFT(a,n,1); FFT(b,n,1);
	F(i,0,n-1) c[i]=c[i]+a[i]*b[i];
	
	FFT(c,n,-1);
	F(i,0,n-1) c[i].x=c[i].x/n;
	int cnt=0;
	F(i,0,ls-lt) if (c[i+lt-1].x<0.5) cnt++;
	printf("%d\n",cnt);
	F(i,0,ls-lt) if (c[i+lt-1].x<0.5) printf("%d\n",i);
}

  

posted @ 2017-02-13 23:40  SfailSth  阅读(132)  评论(0)  编辑  收藏  举报