【BZOJ4259】残缺的字符串
【BZOJ4259】残缺的字符串
Description
很久很久以前,在你刚刚学习字符串匹配的时候,有两个仅包含小写字母的字符串A和B,其中A串长度为m,B串长度为n。可当你现在再次碰到这两个串时,这两个串已经老化了,每个串都有不同程度的残缺。
你想对这两个串重新进行匹配,其中A为模板串,那么现在问题来了,请回答,对于B的每一个位置i,从这个位置开始连续m个字符形成的子串是否可能与A串完全匹配?
Input
第一行包含两个正整数m,n(1<=m<=n<=300000),分别表示A串和B串的长度。
第二行为一个长度为m的字符串A。
第三行为一个长度为n的字符串B。
两个串均仅由小写字母和号组成,其中号表示相应位置已经残缺。
Output
第一行包含一个整数k,表示B串中可以完全匹配A串的位置个数。
若k>0,则第二行输出k个正整数,从小到大依次输出每个可以匹配的开头位置(下标从1开始)。
Sample Input
3 7
ab
aebrob
Sample Output
2
1 5
首先带通配符的字符串匹配好像不能\(kmp\)。
这是\(NTT/FFT\)的一个经典应用。
如果\(A[i]与B[j]\)不能匹配,那么\(j-i+1\)就不能作为匹配的开头位置。
所以我们设一个函数\(\displaystyle GG(x)=\sum_{i=1}^n\sum_{j=1}^m[j-i+1==x]\cdot [A[i]与B[j]不能匹配]\)。我们发现这个函数有点像一个卷积的形式。于是我们将第一个字符串翻转(因为是\(-i\)),然后关键在于怎么构造卷积来使得不同的字符对\(GG\)函数有贡献。
我们将通配符位置的值设为0,其他的设为其在字符表中的序号。然后
\[\begin{align}
\displaystyle GG(x)&=\sum_{i=1}^n\sum_{j=1}^m[j-i+1==x]\cdot (A[i]-B[j])^2A[i]B[j]\\
&=\sum_{i=1}^n\sum_{j=1}^m[j-i+1==x]\cdot(A[i]^3B[j]-2A[i]^2B[j]^2+A[i]B[j]^3)
\end{align}
\]
然后我们做3次FFT就可以了。
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 300005
#define Z complex<double>
#define pi acos(-1)
#define mod 998244353
using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
int n,m;
char s[N],t[N];
int x[N],y[N];
Z f[N<<2],g[N<<2];
int rev[N<<2];
void FFT(Z *a,int d,int flag) {
int n=1<<d;
for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<d-1);
for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int s=1;s<=d;s++) {
int len=1<<s,mid=len>>1;
Z w(cos(2*pi*flag/len),sin(2*pi*flag/len));
for(int i=0;i<n;i+=len) {
Z t(1,0);
for(int j=0;j<mid;j++,t*=w) {
Z u=a[i+j],v=t*a[i+j+mid];
a[i+j]=u+v;
a[i+j+mid]=u-v;
}
}
}
if(flag==-1) for(int i=0;i<n;i++) a[i]/=n;
}
int Match[N<<2];
void solve(int d,int flag) {
FFT(f,d,1),FFT(g,d,1);
for(int i=0;i<(1<<d);i++) f[i]*=g[i];
FFT(f,d,-1);
for(int i=0;i<(1<<d);i++) Match[i]+=flag*(ll)(f[i].real()+0.5);
}
vector<int>ans;
ll cal2(ll a) {return a*a;}
ll cal3(ll a) {return a*a*a;}
int main() {
n=Get(),m=Get();
scanf("%s",s);
scanf("%s",t);
reverse(s,s+n);
for(int i=0;i<n;i++) x[i]=s[i]=='*'?0:s[i]-'a'+1;
for(int i=0;i<m;i++) y[i]=t[i]=='*'?0:t[i]-'a'+1;
int d=ceil(log2(m+n));
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
for(int i=0;i<n;i++) f[i]=Z(cal3(x[i]),0);
for(int i=0;i<m;i++) g[i]=Z(y[i],0);
solve(d,1);
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
for(int i=0;i<n;i++) f[i]=Z(x[i],0);
for(int i=0;i<m;i++) g[i]=Z(cal3(y[i]),0);
solve(d,1);
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
for(int i=0;i<n;i++) f[i]=Z(cal2(x[i]),0);
for(int i=0;i<m;i++) g[i]=Z(cal2(y[i]),0);
solve(d,-2);
for(int i=0;i<m+n;i++) if(Match[i]==0&&1<=i-n+2&&i-n+2<=m-n+1) ans.push_back(i-n+2);
cout<<ans.size()<<"\n";
for(int i=0;i<ans.size();i++) cout<<ans[i]<<" ";
return 0;
}