残缺的字符串
我终于会用多项式来做字符串匹配了
我们设一个匹配函数\(f(x)\)表示\(A\)串的第\(x\)位能否匹配
我们把原字符串上的每一个字母都改成数字,比如\('a'\)变成\(1\),\('b'\)变成\(2\)
于是我们定义
\[f(x)=\sum_{i=1}^m(A_{i+x-1}-B_i)^2
\]
我们考虑一下上面的那个柿子,只有当\(A_{i+x-1}\)始终等于\(B_i\)的时候,\(f(x)=0\),只要有一位上是不一样的,\(f(x)>0\)
也就是这一位匹配上了就是\(0\),没匹配上技就是大于\(0\)
我们考虑加入通配符,只需要强行使得这一位是\(0\)就好了
我们强行定义通配符为\(0\)
于是现在
\[f(x)=\sum_{i=1}^m(A_{i+x-1}-B_i)^2A_{i+x-1}B_i
\]
拆开就是三个柿子
\[\sum_{i=1}^mA_{i+x-1}^3B_i
\]
\[\sum_{i=1}^mA_{i+x-1}B_i^3
\]
\[-2\sum_{i=1}^mA_{i+x-1}^2B_i^2
\]
我们发现我们翻转一下\(B\)串就能变成卷积的形式了
于是三遍\(ntt\)之后输出所有为\(0\)的下标就好了
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
const int maxn=2e6+5;
const int G[2]={3,332748118};
int n,m,len,rev[maxn],ans;
int a[maxn],b[maxn];
int a1[maxn],a2[maxn],a3[maxn];
int b1[maxn],b2[maxn],b3[maxn];
char S[maxn],T[maxn];
const int mod=998244353;
inline int ksm(int a,int b) {
int S=1;
while(b) {if(b&1) S=1ll*S*a%mod;b>>=1;a=1ll*a*a%mod;}
return S;
}
inline void NTT(int *f,int o) {
for(re int i=0;i<len;i++) if(i<rev[i]) std::swap(f[i],f[rev[i]]);
for(re int i=2;i<=len;i<<=1) {
int ln=i>>1,og1=ksm(G[o],(mod-1)/i);
for(re int l=0;l<len;l+=i) {
int t,og=1;
for(re int x=l;x<l+ln;++x) {
t=1ll*f[x+ln]*og%mod;
f[x+ln]=(f[x]-t+mod)%mod;
f[x]=(f[x]+t)%mod;
og=1ll*og*og1%mod;
}
}
}
if(!o) return;
int inv=ksm(len,mod-2);
for(re int i=0;i<len;i++) f[i]=1ll*inv*f[i]%mod;
}
int main() {
scanf("%d%d",&m,&n);
scanf("%s",S+1);scanf("%s",T+1);
for(re int i=1;i<=m;i++) {
if(S[i]=='*') {b[m-i+1]=0;continue;}
b[m-i+1]=S[i]-'a'+1;
}
for(re int i=1;i<=n;i++) {
if(T[i]=='*') {a[i]=0;continue;}
a[i]=T[i]-'a'+1;
}
len=1;while(len<=n+m) len<<=1;
for(re int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|((i&1)?len>>1:0);
for(re int i=1;i<=n;i++) a1[i]=a[i]*a[i]*a[i];
for(re int i=1;i<=n;i++) a2[i]=a[i];
for(re int i=1;i<=n;i++) a3[i]=a[i]*a[i];
for(re int i=1;i<=m;i++) b1[i]=b[i];
for(re int i=1;i<=m;i++) b2[i]=b[i]*b[i]*b[i];
for(re int i=1;i<=m;i++) b3[i]=2*b[i]*b[i];
NTT(a1,0),NTT(a2,0),NTT(a3,0);
NTT(b1,0),NTT(b2,0),NTT(b3,0);
for(re int i=0;i<len;i++)
b1[i]=(1ll*b1[i]*a1[i])%mod,
b2[i]=(1ll*b2[i]*a2[i])%mod,
b3[i]=(1ll*b3[i]*a3[i])%mod;
NTT(b1,1),NTT(b2,1),NTT(b3,1);
for(re int i=m+1;i<=n+m&&i-1<=n;i++)
b1[i]=(b1[i]+b2[i]-b3[i]+mod)%mod,ans+=(b1[i]==0);
printf("%d\n",ans);
for(re int i=m+1;i<=n+m&&i-1<=n;i++)
if(!b1[i]) printf("%d ",i-m);
puts("");
return 0;
}