hdoj5769后缀自动机版本
网上的题解都是后缀数组,我来个后缀自动机题解。
建好后缀自动机后由于后缀自动机是单向的,那么dfs一遍记录各节点的size,要保证一个节点只经过一次才是O(n),否则是O(n^2)。表示这个节点及后面还有几个节点。然后再来个ans数组,再dfs一次。这次如果走的是题目要的字母(记c),那么ans[x]+=siz[to],因为to能到的节点对应的子串都有c。如果走的不是c,那么ans[x]+=ans[to]。ans其实就表示该节点以后能有几个能有c的子串。同样ans只能走一次,否则是n^2的。那么开始时memset为-1。走到一个节点就赋为0。这样-1的点就是还需要dfs的,一个点只走一次。具体看两个dfs函数。
#include <iostream> #include <cstdio> #include <cmath> #include <algorithm> #include <vector> #include <iomanip> #include <cstring> #include <map> #include <queue> #include <set> #include <cassert> #include <stack> #include <bitset> #define mkp make_pair using namespace std; const double EPS=1e-8; typedef long long lon; const lon SZ=100010,SSZ=2*SZ,APB=26,INF=0x7FFFFFFF,mod=1000000007; lon cnt,maxlen[SSZ],minlen[SSZ],nex[SSZ][APB]; lon slink[SSZ],siz[SSZ],ans[SSZ]; char dst,ch[SZ]; lon add(lon pre,lon c) { lon z=++cnt; maxlen[z]=maxlen[pre]+1; lon u=pre; for(;u!=-1&&!nex[u][c];u=slink[u]) { nex[u][c]=z; } if(u==-1) { slink[z]=0; minlen[z]=1; } else { lon x=nex[u][c]; if(maxlen[x]==maxlen[u]+1) { slink[z]=x; minlen[z]=maxlen[slink[z]]+1; } else { lon v=++cnt; memcpy(nex[v],nex[x],sizeof(nex[x])); slink[v]=slink[x]; maxlen[v]=maxlen[u]+1; minlen[v]=maxlen[slink[v]]+1; slink[x]=slink[z]=v; minlen[x]=maxlen[slink[x]]+1; minlen[z]=maxlen[slink[z]]+1; for(;u!=-1&&nex[u][c]==x;u=slink[u]) { nex[u][c]=v; } } } return z; } void init() { scanf(" %c",&dst); scanf(" %s",ch+1); lon pre=0; slink[0]=-1; cnt=0; memset(ans,-1,sizeof(ans)); for(lon i=1;ch[i];++i) { pre=add(pre,ch[i]-'a'); } } void dfs1(lon x) { siz[x]=1; for(lon i=0;i<APB;++i) { lon t=nex[x][i]; if(t) { if(!siz[t])dfs1(t); siz[x]+=siz[t]; } } } void dfs2(lon x) { ans[x]=0; for(lon i=0;i<APB;++i) { lon t=nex[x][i]; if(t) { if(ans[t]==-1)dfs2(t); if(i==dst-'a')ans[x]+=siz[t]; else ans[x]+=ans[t]; } } } void work() { dfs1(0); dfs2(0); cout<<ans[0]<<endl; for(int i=0;i<=cnt;++i) { memset(nex[i],0,sizeof(nex[i])); siz[i]=0; } } int main() { //std::ios::sync_with_stdio(0); //freopen("d:\\1.txt","r",stdin); lon casenum; cin>>casenum; //cout<<casenum<<endl; for(lon time=1;time<=casenum;++time) //for(lon time=1;cin>>n>>len>>wid;++time) { cout<<"Case #"<<time<<": "; init(); work(); } return 0; }