BZOJ 4503: 两个串
Description
带通配符的字符串匹配：给定文本串 \(S\) 和模式串 \(T\)（\(T\) 中可能含有能匹配任意字符的通配符 `?`），求 \(T\) 在 \(S\) 中出现的次数以及每次出现的起始位置（下标从 \(0\) 开始）。\(n,m\leqslant 10^5\)
Solution
FFT.
\(D_k=\sum_{i+j=k}(S_i-T_j)^2T_j=\sum_{i+j=k}S_i^2T_j-2\sum_{i+j=k}S_iT_j^2+\sum_{i+j=k}T_j^3\)，展开后每一项都是卷积，可以用 FFT 求出。
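先把 \(T\) 翻转，再把它化成这样的式子。每一项 \((S_i-T_j)^2T_j\) 都非负，并且当且仅当两个位置相等（\(S_i=T_j\)）或者 \(T_j\) 为 \(0\) 时为 \(0\)。因此 \(D_k=0\) 当且仅当 \(T\) 能匹配 \(S\) 中以 \(k-m+1\) 为起点的子串。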
把字母 `a`–`z` 映射为 \(1\sim 26\)，把通配符设成 \(0\) 即可。
Code
// BZOJ 4503 ("Two Strings"): count and list every position where pattern T,
// which may contain the wildcard '?', occurs in text S, using FFT convolution.
//
// Letters are encoded as 1..26 and '?' as 0. With T reversed, define
//   D_k = sum_{i+j=k} (S_i - T_j)^2 * T_j
//       = sum S_i^2 T_j - 2 sum S_i T_j^2 + sum T_j^3.
// Every term is non-negative and vanishes exactly when S_i == T_j or
// T_j == 0, so D_k == 0 iff T matches S starting at position k - (m - 1).
#include <cmath>
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

namespace Pol {

// Minimal complex number used by the FFT.
struct cp {
    double re;
    double im;
};

inline cp operator+(const cp &a, const cp &b) { return cp{a.re + b.re, a.im + b.im}; }
inline cp operator-(const cp &a, const cp &b) { return cp{a.re - b.re, a.im - b.im}; }
inline cp operator*(const cp &a, const cp &b) {
    return cp{a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re};
}
inline cp operator*(double k, const cp &a) { return cp{k * a.re, k * a.im}; }

// Smallest power of two that is >= n (returns 1 for n <= 1).
inline std::size_t pow2_at_least(std::size_t n) {
    std::size_t len = 1;
    while (len < n) len <<= 1;
    return len;
}

// In-place iterative radix-2 FFT over a.size() points, which must be a power
// of two. dir = +1 computes the forward transform; dir = -1 computes the
// inverse transform, including the 1/n scaling.
void dft(std::vector<cp> &a, int dir) {
    const std::size_t n = a.size();
    if (n <= 1) return;

    // Bit-reversal permutation: j always holds the bit-reverse of i.
    for (std::size_t i = 0, j = 0; i < n; ++i) {
        if (i < j) std::swap(a[i], a[j]);
        for (std::size_t bit = n >> 1; (j ^= bit) < bit; bit >>= 1) {
        }
    }

    // Twiddles are computed directly with cos/sin instead of by repeated
    // multiplication, so rounding error does not accumulate across a level.
    const double pi = std::acos(-1.0);
    std::vector<cp> root(n / 2);
    for (std::size_t k = 0; k < n / 2; ++k) {
        const double ang = 2.0 * pi * static_cast<double>(k) / static_cast<double>(n);
        root[k] = cp{std::cos(ang), dir * std::sin(ang)};
    }

    for (std::size_t len = 2; len <= n; len <<= 1) {
        const std::size_t half = len / 2;
        const std::size_t step = n / len;  // root[t * step] == exp(dir * 2*pi*i * t / len)
        for (std::size_t start = 0; start < n; start += len) {
            for (std::size_t t = 0; t < half; ++t) {
                const cp u = a[start + t];
                const cp v = root[t * step] * a[start + t + half];
                a[start + t] = u + v;
                a[start + t + half] = u - v;
            }
        }
    }

    if (dir == -1) {
        const double inv = 1.0 / static_cast<double>(n);
        for (cp &z : a) {
            z.re *= inv;
            z.im *= inv;
        }
    }
}

}  // namespace Pol

// Returns, in ascending order, the 0-based start positions at which t occurs
// in s, where each '?' in t matches any character. s is expected to contain
// only 'a'..'z'. Returns an empty list when t is empty or longer than s.
std::vector<std::size_t> find_matches(const std::string &s, const std::string &t) {
    using Pol::cp;
    const std::size_t n = s.size();
    const std::size_t m = t.size();
    std::vector<std::size_t> hits;
    if (m == 0 || m > n) return hits;

    // A cyclic convolution of length >= n is alias-free on the indices read
    // below (k in [m-1, n-1]): an aliased pair would need i + j == k + len,
    // but i + j <= (n-1) + (m-1) < (m-1) + n <= k + len.
    const std::size_t len = Pol::pow2_at_least(n);

    // Value-initialized, so all padding entries are 0.
    std::vector<cp> s1(len), s2(len), t1(len), t2(len);
    for (std::size_t i = 0; i < n; ++i) {
        const double v = s[i] - 'a' + 1;
        s1[i].re = v;
        s2[i].re = v * v;
    }

    // sum_j T_j^3 contributes the same value to every k in [m-1, n-1], so it
    // is kept as a scalar instead of being computed as a third convolution.
    double t3_sum = 0.0;
    for (std::size_t j = 0; j < m; ++j) {
        const char c = t[m - 1 - j];  // T is read reversed
        const double v = (c == '?') ? 0.0 : static_cast<double>(c - 'a' + 1);
        t1[j].re = v;
        t2[j].re = v * v;
        t3_sum += v * v * v;
    }

    Pol::dft(s1, 1);
    Pol::dft(s2, 1);
    Pol::dft(t1, 1);
    Pol::dft(t2, 1);

    // The inverse transform is linear, so combine the two products in
    // frequency space and invert only once.
    std::vector<cp> d(len);
    for (std::size_t k = 0; k < len; ++k) {
        d[k] = s2[k] * t1[k] - 2.0 * (s1[k] * t2[k]);
    }
    Pol::dft(d, -1);

    for (std::size_t k = m - 1; k < n; ++k) {
        // D_k is a non-negative integer; |D_k| < 0.5 means 0 up to FFT rounding.
        if (std::fabs(d[k].re + t3_sum) < 0.5) hits.push_back(k - (m - 1));
    }
    return hits;
}

// Reads S and T from stdin, then prints the number of occurrences followed by
// each 0-based start position on its own line.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::string s;
    std::string t;
    std::cin >> s >> t;

    const std::vector<std::size_t> hits = find_matches(s, t);
    std::cout << hits.size() << '\n';
    for (const std::size_t p : hits) std::cout << p << '\n';
    return 0;
}