【BZOJ4259】残缺的字符串

【BZOJ4259】残缺的字符串

Description

  很久很久以前,在你刚刚学习字符串匹配的时候,有两个仅包含小写字母的字符串A和B,其中A串长度为m,B串长度为n。可当你现在再次碰到这两个串时,这两个串已经老化了,每个串都有不同程度的残缺。
  你想对这两个串重新进行匹配,其中A为模板串,那么现在问题来了,请回答,对于B的每一个位置i,从这个位置开始连续m个字符形成的子串是否可能与A串完全匹配?

Input

  第一行包含两个正整数m,n(1<=m<=n<=300000),分别表示A串和B串的长度。
  第二行为一个长度为m的字符串A。
  第三行为一个长度为n的字符串B。
  两个串均仅由小写字母和号组成,其中号表示相应位置已经残缺。

Output

  第一行包含一个整数k,表示B串中可以完全匹配A串的位置个数。
  若k>0,则第二行输出k个正整数,从小到大依次输出每个可以匹配的开头位置(下标从1开始)。

Sample Input

3 7
ab
aebr
ob

Sample Output

2
1 5

首先带通配符的字符串匹配好像不能\(kmp\)

这是\(NTT/FFT\)的一个经典应用。

如果\(A[i]与B[j]\)不能匹配,那么\(j-i+1\)就不能作为匹配的开头位置。

所以我们设一个函数\(\displaystyle GG(x)=\sum_{i=1}^n\sum_{j=1}^m[j-i+1==x]\cdot [A[i]与B[j]不能匹配]\)。我们发现这个函数有点像一个卷积的形式。于是我们将第一个字符串翻转(因为是\(-i\)),然后关键在于怎么构造卷积来使得不同的字符对\(GG\)函数有贡献。

我们将通配符位置的值设为0,其他的设为其在字符表中的序号。然后

\[\begin{align} \displaystyle GG(x)&=\sum_{i=1}^n\sum_{j=1}^m[j-i+1==x]\cdot (A[i]-B[j])^2A[i]B[j]\\ &=\sum_{i=1}^n\sum_{j=1}^m[j-i+1==x]\cdot(A[i]^3B[j]-2A[i]^2B[j]^2+A[i]B[j]^3) \end{align} \]

然后我们做3次FFT就可以了。

代码:

#include<bits/stdc++.h>
#define ll long long
#define N 300005
#define Z complex<double>
#define pi acos(-1)
#define mod 998244353

using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}

int n,m;
char s[N],t[N];
int x[N],y[N];
Z f[N<<2],g[N<<2];
int rev[N<<2];
void FFT(Z *a,int d,int flag) {
    int n=1<<d;
    for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<d-1);
    for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int s=1;s<=d;s++) {
        int len=1<<s,mid=len>>1;
        Z w(cos(2*pi*flag/len),sin(2*pi*flag/len));
        for(int i=0;i<n;i+=len) {
            Z t(1,0);
            for(int j=0;j<mid;j++,t*=w) {
                Z u=a[i+j],v=t*a[i+j+mid];
                a[i+j]=u+v;
                a[i+j+mid]=u-v;
            }
        }
    }
    if(flag==-1) for(int i=0;i<n;i++) a[i]/=n;
}
int Match[N<<2];
void solve(int d,int flag) {
    FFT(f,d,1),FFT(g,d,1);
    for(int i=0;i<(1<<d);i++) f[i]*=g[i];
    FFT(f,d,-1);
    for(int i=0;i<(1<<d);i++) Match[i]+=flag*(ll)(f[i].real()+0.5);
}

vector<int>ans;
ll cal2(ll a) {return a*a;}
ll cal3(ll a) {return a*a*a;}

int main() {
    n=Get(),m=Get();
    scanf("%s",s);
    scanf("%s",t);
    reverse(s,s+n);
    
    for(int i=0;i<n;i++) x[i]=s[i]=='*'?0:s[i]-'a'+1;
    for(int i=0;i<m;i++) y[i]=t[i]=='*'?0:t[i]-'a'+1;
    int d=ceil(log2(m+n));
    
    memset(f,0,sizeof(f));
    memset(g,0,sizeof(g));
    for(int i=0;i<n;i++) f[i]=Z(cal3(x[i]),0);
    for(int i=0;i<m;i++) g[i]=Z(y[i],0);
    solve(d,1);
    
    memset(f,0,sizeof(f));
    memset(g,0,sizeof(g));
    for(int i=0;i<n;i++) f[i]=Z(x[i],0);
    for(int i=0;i<m;i++) g[i]=Z(cal3(y[i]),0);
    solve(d,1);
    
    memset(f,0,sizeof(f));
    memset(g,0,sizeof(g));
    for(int i=0;i<n;i++) f[i]=Z(cal2(x[i]),0);
    for(int i=0;i<m;i++) g[i]=Z(cal2(y[i]),0);
    solve(d,-2);

    for(int i=0;i<m+n;i++) if(Match[i]==0&&1<=i-n+2&&i-n+2<=m-n+1) ans.push_back(i-n+2);
    cout<<ans.size()<<"\n";
    
	for(int i=0;i<ans.size();i++) cout<<ans[i]<<" ";
    return 0;
}

posted @ 2018-12-01 08:56  hec0411  阅读(196)  评论(0编辑  收藏  举报