[Lydsy2017省队十连测]航海舰队
SOL: 我们用FFT匹配字符串。
不知道为什么我的NTT挂了,贴一个别人的FFT。
#include<bits/stdc++.h> #define ll long long #define inf 2e9 #define PI acos(-1.0) #define pii pair<int,int> #define fi first #define se second #define mk make_pair using namespace std; const int N=1e6+5; const int MXN=7e2+5; struct cp { double x,y; cp(double _x=0,double _y=0) { x=_x,y=_y;} }A[N],B[N]; cp operator + (cp x,cp y){cp z;z.x=x.x+y.x;z.y=x.y+y.y;return z;} cp operator - (cp x,cp y){cp z;z.x=x.x-y.x;z.y=x.y-y.y;return z;} cp operator * (cp x,cp y){cp z;z.x=x.x*y.x-x.y*y.y;z.y=x.y*y.x+x.x*y.y;return z;} int ans,up=inf,dn,lf=inf,rt,W,H,all,k,len,M,n,m,r[N],T[N]; queue<pii >q; char a[MXN][MXN]; bool vis[N],v[N]; void init(){ for(int i=0;i<n;i++) for(int j=0;j<m;j++){ if(a[i][j]=='#') B[i*m+j]=cp(1,0); else if(a[i][j]=='o'){ up=min(up,i),dn=max(dn,i); lf=min(lf,j),rt=max(rt,j); } } for(int i=up;i<=dn;i++) for(int j=lf;j<=rt;j++) if(a[i][j]=='o')T[(i-up)*m+j-lf]=1; W=rt-lf+1,H=dn-up+1,len=(H-1)*m+W; } void FFT(cp *x,int f){ for(int i=0;i<M;i++) if(r[i]>i) swap(x[r[i]],x[i]); for(int i=1;i<M;i<<=1){ cp wn(cos(PI/i),f*sin(PI/i)); for(int j=0;j<M;j+=i<<1){ cp w=1; for(int k=0;k<i;k++,w=w*wn){ cp a=x[j+k],b=w*x[j+k+i]; x[j+k]=a+b,x[j+k+i]=a-b; } } } if(f==-1) for(int i=0;i<M;i++) x[i].x/=M; } void work(){ q.push(mk(dn,rt)); while(!q.empty()){ int x=q.front().fi,y=q.front().se;q.pop(); if(x<0||x>=n||y<0||y>=m) continue; int z=x*m+y; if(!v[z]||vis[z]) continue; vis[z]=1; q.push(mk(x+1,y)),q.push(mk(x-1,y)),q.push(mk(x,y-1)),q.push(mk(x,y+1)); } } int main(){ cin>>n>>m,all=n*m; for(int i=0;i<n;i++) scanf("%s",a[i]); init(); for(M=1;M<=all;M<<=1,k++); for(int i=1;i<M;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(k-1)); for(int i=0;i<len;i++) if(T[i]) A[len-i-1]=1; FFT(A,1),FFT(B,1); for(int i=0;i<M;i++) A[i]=A[i]*B[i]; FFT(A,-1); for(int i=H-1;i<n;i++) for(int j=W-1;j<m;j++) if(A[i*m+j].x<0.5) v[i*m+j]=1; work(); for(int i=0;i<M;i++) A[i]=T[i]; for(int i=0;i<M;i++) B[i]=vis[i]; FFT(A,1),FFT(B,1); for(int i=0;i<M;i++) A[i]=A[i]*B[i]; FFT(A,-1); for(int i=0;i<M;i++) if(A[i].x>0.5) ans++; return printf("%d\n",ans),0; }