BZOJ4503: 两个串
Description
兔子们在玩两个串的游戏。给定两个字符串S和T,兔子们想知道T在S中出现了几次,
分别在哪些位置出现。注意T中可能有“?”字符,这个字符可以匹配任何字符。
Input
两行两个字符串,分别代表S和T
Output
第一行一个正整数k,表示T在S中出现了几次
接下来k行正整数,分别代表T每次在S中出现的开始位置。按照从小到大的顺序输出,S下标从0开始。
Sample Input
ababcadaca
a?a
a?a
Sample Output
3
0
5
0
5
HINT
S 长度不超过 10^5, T 长度不会超过 S。 S 中只包含小写字母, T中只包含小写字母和“?”
这种通配符匹配问题一般就是多项式乘法。
把T串倒过来,把a到z变成1到26,'?'变成0,设计一个函数F(x)=∑(A[i]-B[j])^2*A[i]*B[j],这样对于两个位置i,j,恰好满足题意(A[i]=A[j]或A[i]、B[j]中至少有一个?)。
拆一下F(x)=∑A[i]^3 * B[j] + B[j]^3 * A[i] - 2*A[i]*A[i] * B[j]*B[j]
三次NTT乘法。
#include<cstdio> #include<cctype> #include<queue> #include<cstring> #include<algorithm> #define rep(i,s,t) for(int i=s;i<=t;i++) #define dwn(i,s,t) for(int i=s;i>=t;i--) #define ren for(int i=first[x];i;i=next[i]) using namespace std; inline int read() { int x=0,f=1;char c=getchar(); for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; } typedef long long ll; const int G=3; const int p=998244353; const int maxn=270010; int pow(ll n,int m) { ll ans=1; for(;m;(n*=n)%=p,m>>=1) if(m&1) (ans*=n)%=p; return ans; } int wn[20]; void NTT(int* A,int len,int tp) { int j=len>>1,c=0; rep(i,1,len-2) { if(i<j) swap(A[i],A[j]);int k=len>>1; while(j>=k) j-=k,k>>=1;j+=k; } for(int i=2;i<=len;i<<=1) { c++;for(int j=0;j<len;j+=i) { int w=1; for(int k=j;k<j+(i>>1);k++) { int u=A[k],t=(ll)A[k+(i>>1)]*w%p; A[k]=(u+t)%p;A[k+(i>>1)]=(u-t+p)%p; w=(ll)w*wn[c]%p; } } } if(tp<0) { int inv=pow(len,p-2); rep(i,1,len/2-1) swap(A[i],A[len-i]); rep(i,0,len-1) A[i]=(ll)A[i]*inv%p; } } char s[maxn],t[maxn]; int A[maxn],B[maxn],ans[maxn]; int main() { rep(i,1,19) wn[i]=pow(G,(p-1)/(1<<i)); scanf("%s%s",s,t); int n=strlen(s),m=strlen(t),len=1;while(len<(n<<1)||len<(m<<1)) len<<=1; rep(i,0,n-1) s[i]-='a'-1; rep(i,0,m-1) t[i]=((t[i]=='?')?0:t[i]-'a'+1); //A[i]^3 * B[j] rep(i,0,n-1) A[i]=s[i]*s[i]*s[i]; rep(i,0,m-1) B[m-i-1]=t[i]; NTT(A,len,1);NTT(B,len,1); rep(i,0,len-1) A[i]=(ll)A[i]*B[i]%p; NTT(A,len,-1); rep(i,0,len-1) (ans[i]+=A[i])%=p; //B[j]^3 * A[i] memset(A,0,sizeof(A));memset(B,0,sizeof(B)); rep(i,0,n-1) A[i]=s[i]; rep(i,0,m-1) B[m-i-1]=t[i]*t[i]*t[i]; NTT(A,len,1);NTT(B,len,1); rep(i,0,len-1) A[i]=(ll)A[i]*B[i]%p; NTT(A,len,-1); rep(i,0,len-1) (ans[i]+=A[i])%=p; //- 2*A[i]*A[i] * B[j]*B[j] memset(A,0,sizeof(A));memset(B,0,sizeof(B)); rep(i,0,n-1) A[i]=2*s[i]*s[i]; rep(i,0,m-1) B[m-i-1]=t[i]*t[i]; NTT(A,len,1);NTT(B,len,1); rep(i,0,len-1) A[i]=(ll)A[i]*B[i]%p; NTT(A,len,-1); int res=0; rep(i,m-1,n-1) if(ans[i]==A[i]) res++; printf("%d\n",res); rep(i,m-1,n-1) if(ans[i]==A[i]) printf("%d\n",i-m+1); return 0; }