【bzoj4259】残缺的字符串 FFT
Description
很久很久以前,在你刚刚学习字符串匹配的时候,有两个仅包含小写字母的字符串A和B,其中A串长度为m,B串长度为n。可当你现在再次碰到这两个串时,这两个串已经老化了,每个串都有不同程度的残缺。
你想对这两个串重新进行匹配,其中A为模板串,那么现在问题来了,请回答,对于B的每一个位置i,从这个位置开始连续m个字符形成的子串是否可能与A串完全匹配?
Input
第一行包含两个正整数m,n(1<=m<=n<=300000),分别表示A串和B串的长度。
第二行为一个长度为m的字符串A。
第三行为一个长度为n的字符串B。
两个串均仅由小写字母和号组成,其中号表示相应位置已经残缺。
Output
第一行包含一个整数k,表示B串中可以完全匹配A串的位置个数。
若k>0,则第二行输出k个正整数,从小到大依次输出每个可以匹配的开头位置(下标从1开始)。
Sample Input
3 7
ab
aebrob
Sample Output
2
1 5
Sol
带通配符还要求每一位能不能匹配,那就是FFT题啦。
我们把\(*\)始终视为0,然后冷静分析一下:
因为FFT的卷积是乘积相加的形式,我们无法直接用减法,而是要转化成乘法,所以我们对于原式平方一下,得到:
\(\sum(a_i-b_i)^2=\sum a_i^2*[b_i>0]+\sum b_i^2*[a_i>0]-2*\sum a_ib_i\)
翻转其中一个串之后,就可以FFT了。
Code
#include <bits/stdc++.h>
#define pi acos(-1)
using namespace std;
int i,j,k,n,m,len,a[1048577],b[1048577],ans[1048577],tot;char aa[1048577],bb[1048577];
struct cp
{
double x,y;
cp(double x=0.0,double y=0.0):x(x),y(y){}
cp operator +(const cp a)const{return cp(x+a.x,y+a.y);}
cp operator -(const cp a)const{return cp(x-a.x,y-a.y);}
cp operator *(const cp a)const{return cp(x*a.x-y*a.y,x*a.y+y*a.x);}
}w,wn,t,A[1048577],B[1048577],C[1048577];
void fft(cp *a,int n,int op)
{
for(i=k=0;i<n;i++){if(i>k) swap(a[i],a[k]);for(j=(n>>1);(k^=j)<j;j>>=1);}
for(k=2,wn=cp(cos(2*pi*op/k),sin(2*pi*op/k));k<=n;k<<=1,wn=cp(cos(2*pi*op/k),sin(2*pi*op/k)))
for(i=0,w=cp(1,0);i<n;i+=k,w=cp(1,0)) for(j=0;j<(k>>1);j++,w=w*wn)
t=w*a[i+j+(k>>1)],a[i+j+(k>>1)]=a[i+j]-t,a[i+j]=a[i+j]+t;
if(op==-1) for(int i=0;i<n;i++) a[i].x/=n;
}
int main()
{
for(scanf("%d%d%s%s",&m,&n,aa,bb),i=0,j=m-1;i<j;i++,j--) swap(aa[i],aa[j]);
for(i=0;i<m;i++) if(aa[i]!='*') a[i]=aa[i]-'a'+1;
for(i=0;i<n;i++) if(bb[i]!='*') b[i]=bb[i]-'a'+1;
for(len=1;len<n+m;len<<=1);
for(i=0;i<len;i++) A[i]=cp(a[i]*a[i],0),B[i]=cp(b[i]>0,0);
for(fft(A,len,1),fft(B,len,1),i=0;i<len;i++) C[i]=C[i]+A[i]*B[i];
for(i=0;i<len;i++) A[i]=cp(a[i]>0,0),B[i]=cp(b[i]*b[i],0);
for(fft(A,len,1),fft(B,len,1),i=0;i<len;i++) C[i]=C[i]+A[i]*B[i];
for(i=0;i<len;i++) A[i]=cp(a[i],0),B[i]=cp(b[i],0);
for(fft(A,len,1),fft(B,len,1),i=0;i<len;i++) C[i]=C[i]-A[i]*B[i]*cp(2.0,0.0);
for(fft(C,len,-1),i=m-1;i<n;i++) if(C[i].x<0.1) ans[++tot]=i-m+2;
for(printf("%d\n%d",tot,ans[1]),i=2;i<=tot;i++) printf(" %d",ans[i]);
}