hdu4787 AC自动机加分块
这题说的是 有n次操作 +w 表示读入一个字符串,?p 询问这个字符串的子串在那些模板串中有多少个,
http://blog.csdn.net/qq574857122/article/details/16826631
就是说先存一部分的字符串,因为每次都要进行重新 建立这个失配指针,也就是说让适当的单词进行失配 指针重建 会达到高效,两个ac自动机,选取sqrt(100000)的时候达到相对优一点,这样我们 当第一棵树超过了 800的时候我们就将第一棵树的东西存入第二棵树这样我们可以相对减少对失配指针的重建
#include <iostream> #include <algorithm> #include <string.h> #include <queue> #include <cstdio> using namespace std; const int maxn=600000; struct Aho{ int ch[maxn][2]; int f[maxn]; bool val[maxn]; int last[maxn]; int sz; void init() { sz=1; ch[0][0]=ch[0][1]=0;val[0]=0;last[0]=0; } int idx(char c){ return c=='0'?0:1; } void insert(char *s,int n) { int u=0; for(int i=0; i<n; i++) { int c=idx(s[i]); if(ch[u][c]==0) { ch[sz][0]=ch[sz][1]=0; val[sz]=0; ch[u][c]=sz++; } u=ch[u][c]; } val[u]=true; } bool search(char *s,int n) { int u=0; for(int i=0; i<n; i++) { int c=idx(s[i]); if(ch[u][c]==0)return false; u=ch[u][c]; } return val[u]; } int print(int j) { int ans=0; while(j) { ans++;j=last[j]; } return ans; } int find(char *T,int n) { int j=0; int ans=0; for(int i=0; i<n; i++) { int c=idx(T[i]); while(j&&ch[j][c]==0)j=f[j]; j=ch[j][c]; if(val[j])ans+=print(j); else if(last[j])ans+=print(last[j]); } return ans; } void getFail() { queue<int>q; last[0]=f[0]=0; for(int c=0; c<2; c++) { int u=ch[0][c]; if(u){q.push(u);f[u]=last[u]=0;} } while(!q.empty()) { int r = q.front(); q.pop(); for(int c=0; c<2; c++) { int u=ch[r][c]; if(u==0) { continue; } q.push(u); int v=f[r]; while(v&&ch[v][c]==0)v=f[v]; f[u]=ch[v][c]; last[u]=val[f[u]]?f[u]:last[f[u]]; } } } }ac,buf; char s[5000010],temp[5000010]; void dfs(int u,int v) { queue<int>U,V; U.push(0);V.push(0); while(!U.empty()) { u=U.front();U.pop(); v=V.front();V.pop(); for(int i=0;i<2; i++) if(buf.ch[v][i]) { int e2=buf.ch[v][i]; if(ac.ch[u][i]==0) { ac.ch[ac.sz][0]=ac.ch[ac.sz][1]=0; ac.val[ac.sz]=0; ac.ch[u][i]=ac.sz++; } int e1=ac.ch[u][i]; ac.val[e1]|=buf.val[e2]; U.push(e1);V.push(e2); } } } void join() { dfs(0,0); buf.init(); ac.getFail(); } int main() { int cas; scanf("%d",&cas); for(int cc=1; cc<=cas ; cc++) { int n; scanf("%d",&n); ac.init(); buf.init(); int L=0; printf("Case #%d:\n",cc); for(int i=0;i<n; i++) { scanf("%s",temp); int len=strlen(temp+1); s[0]=temp[0]; for(int i=0; i<len; i++) s[i+1]=temp[1+((i+L)%len)]; if(s[0]=='+') { if(buf.search(s+1,len)|| ac.search(s+1,len))continue; buf.insert(s+1,len); buf.getFail(); if(buf.sz>800)join(); }else { L=buf.find(s+1,len)+ac.find(s+1,len); printf("%d\n",L); } } } return 0; }