bzoj4503: 两个串
4503: 两个串
链接&&题意
bzoj
兔子们在玩两个串的游戏。给定两个字符串S和T,兔子们想知道T在S中出现了几次,
分别在哪些位置出现。注意T中可能有“?”字符,这个字符可以匹配任何字符。
思路
没有通配符。
\(\sum\limits_{i=1}^{m-1}(S[j+i]-T[i])^2=0\)
旋转T
\(\sum\limits_{i=1}^{m-1}(S[j+i]-T[m-i-1])^2=0\)
有通配符
‘?’=0
\(\sum\limits_{i=1}^{m-1}(S[j+i]-T[m-i-1])^2T[m-1-i]=0\)
展开直接fft就行了
吐槽
本来写的挺详细的,编辑器抽搐了,没自动保存。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=4e5+7;
const double Pi=acos(-1),eps=1e-9;
int n,m,limit=1,r[N],p;
double vis[N];
char S[N],T[N];
int ff(int a) {return a*a;}
int fff(int a) {return a*a*a;}
struct Complex {
double x,y;
Complex(double xx=0,double yy=0) {x=xx,y=yy;}
}a[N],b[N];
Complex operator + (Complex a,Complex b) {return Complex(a.x+b.x,a.y+b.y);}
Complex operator - (Complex a,Complex b) {return Complex(a.x-b.x,a.y-b.y);}
Complex operator * (Complex a,Complex b) {return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
void fft(Complex *a,int type) {
for(int i=0;i<limit;++i)
if(i<r[i]) swap(a[i],a[r[i]]);
for(int mid=1;mid<limit;mid<<=1) {
Complex Wn(cos(Pi/mid),type*sin(Pi/mid));
for(int i=0;i<limit;i+=(mid<<1)) {
Complex w(1,0);
for(int j=0;j<mid;++j,w=w*Wn) {
Complex x=a[i+j],y=w*a[i+j+mid];
a[i+j]=x+y;
a[i+j+mid]=x-y;
}
}
}
if(type==-1) for(int i=0;i<=limit;++i) a[i].x=a[i].x/limit;
}
void work() {
fft(a,1),fft(b,1);
for(int i=0;i<limit;++i) a[i]=a[i]*b[i];
fft(a,-1);
}
int main() {
scanf("%s%s",S,T);
n=strlen(S),m=strlen(T);
for(int i=0;i<=m/2;++i) swap(T[i],T[m-i-1]);
while(limit<n+m) limit<<=1,p++;
for(int i=0;i<limit;++i)
r[i]=(r[i>>1]>>1)|((i&1)<<(p-1));
double sum=0;for(int i=0;i<m;++i) sum+=fff(T[i]=='?'?0:T[i]-'a'+1);
for(int i=0;i<=limit;++i) a[i]=b[i]=0;
for(int i=0;i<n;++i) a[i].x=ff(S[i]-'a'+1);
for(int i=0;i<m;++i) b[i].x=T[i]=='?'?0:T[i]-'a'+1;
work();
for(int i=0;i<limit;++i) vis[i]=(vis[i]+round(a[i].x));
memset(a,0,sizeof(a)),memset(b,0,sizeof(b));
for(int i=0;i<=limit;++i) a[i]=b[i]=0;
for(int i=0;i<n;++i) a[i].x=(S[i]-'a'+1);
for(int i=0;i<m;++i) b[i].x=ff(T[i]=='?'?0:T[i]-'a'+1);
work();
for(int i=0;i<limit;++i) vis[i]=(vis[i]-2.0*round(a[i].x));
int ans=0;
for(int i=0;i<=n-m;++i) if(fabs(vis[m+i-1]+sum)<eps) ans++;
printf("%d\n",ans);
for(int i=0;i<=n-m;++i) if(fabs(vis[m+i-1]+sum)<eps) printf("%d\n",i);
return 0;
}