【BZOJ4259】残缺的字符串-FFT
测试地址:残缺的字符串
题目大意:给定两个带通配符的字符串,问在中能匹配的所有位置。
做法:本题需要用到FFT。
这题看上去是一个字符串题,然而KMP算法在带通配符的字符串匹配中并不好使,这时候就要分析一下两个带通配符的串匹配的条件。
两个字符匹配当且仅当两个字符相同或其中一个为通配符,如果令通配符为,那么我们可以定义两个字符的差异值为:
可以看出,两个字符匹配当且仅当它们的差异值为,否则差异值总是正数。那么对于两个带通配符的字符串,我们可以求出:
那么两个字符串能够匹配当且仅当这个式子的值为。于是我们先把翻转,然后把上式展开,发现上式是三个卷积形式的式子的和,那么用三次FFT就可以求解出来了。
吐槽:BZOJ的数据真的非常严格……本蒟蒻经历了WA->TLE->WA->AC的艰辛历程,好不容易才过了,我还是太弱了T_T……
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const double eps=1.0;
const double pi=acos(-1.0);
int m,n,r[1200010],bit,x;
char A[300010],B[300010];
double f[300010][3],g[300010][3],val[300010],ans[1200010]={0};
struct Complex
{
double x,y;
}a[1200010],b[1200010];
Complex operator + (Complex a,Complex b) {Complex s={a.x+b.x,a.y+b.y};return s;}
Complex operator - (Complex a,Complex b) {Complex s={a.x-b.x,a.y-b.y};return s;}
Complex operator * (Complex a,Complex b) {Complex s={a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};return s;}
void FFT(Complex *a,int n,int type)
{
for(int i=0;i<n;i++)
if (i<r[i]) swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1)
{
Complex W={cos(pi/mid),type*sin(pi/mid)};
for(int l=0;l<n;l+=(mid<<1))
{
Complex w={1.0,0.0};
for(int k=0;k<mid;k++,w=w*W)
{
Complex x=a[l+k],y=w*a[l+mid+k];
a[l+k]=x+y;
a[l+mid+k]=x-y;
}
}
}
if (type==-1)
{
for(int i=0;i<n;i++)
a[i].x/=(double)n;
}
}
void calc_f()
{
for(int mode=0;mode<3;mode++)
{
for(int x=0;x<m;x++)
{
if (A[x]=='*') {f[x][mode]=0.0;continue;}
if (mode==0) f[x][mode]=(double)(A[x]*A[x]*A[x]);
if (mode==1) f[x][mode]=(double)(A[x]*A[x]);
if (mode==2) f[x][mode]=(double)(A[x]);
}
}
}
void calc_g()
{
for(int mode=0;mode<3;mode++)
{
for(int x=0;x<n;x++)
{
if (B[x]=='*') {g[x][mode]=0.0;continue;}
if (mode==0) g[x][mode]=(double)(B[x]);
if (mode==1) g[x][mode]=(double)(B[x]*B[x]);
if (mode==2) g[x][mode]=(double)(B[x]*B[x]*B[x]);
}
}
}
double calc_val()
{
val[0]=val[2]=1.0;
val[1]=-2.0;
}
void solve(int mode)
{
memset(a,0,sizeof(a)),memset(b,0,sizeof(b));
for(int i=0;i<m;i++)
a[i].x=f[m-1-i][mode];
for(int i=0;i<n;i++)
b[i].x=g[i][mode];
FFT(a,x,1),FFT(b,x,1);
for(int i=0;i<x;i++)
a[i]=a[i]*b[i];
FFT(a,x,-1);
for(int i=0;i<x;i++)
ans[i]+=val[mode]*a[i].x;
}
int main()
{
scanf("%d%d",&m,&n);
scanf("%s",A);
scanf("%s",B);
bit=0,x=1;
while(x<n+m) x<<=1,bit++;
r[0]=0;
for(int i=1;i<=x;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));
calc_f(),calc_g(),calc_val();
solve(0);
solve(1);
solve(2);
int cnt=0;
for(int i=0;i<=n-m;i++)
if (fabs(ans[i+m-1])<eps) cnt++;
printf("%d\n",cnt);
for(int i=0;i<=n-m;i++)
if (fabs(ans[i+m-1])<eps) printf("%d ",i+1);
return 0;
}