[TJOI2013]单词
题目大意:
给定由$n$个单词组成的文章$(\sum|s_i|\leq10^6)$,统计每个单词出现的次数。
思路:
这一题乍一眼看上去像是用所有单词构造一个AC自动机,然后再用每个字符串去匹配,每次跳转失配指针统计答案。这样确实能在BZOJ上A掉,但是在洛谷上却被卡成90分(考虑一些很长的字符串仅由同一种字母构成,且长度依次递增,这样匹配每个串是$O(n^2)$的)。正确的算法应该是用所有的失配指针构造Fail树,每次匹配不跳转指针,只在当前结点上记录匹配的次数,最后再在Fail树上推一遍。
1 #include<list> 2 #include<queue> 3 #include<cstdio> 4 #include<cctype> 5 #include<cstring> 6 inline int getint() { 7 register char ch; 8 while(!isdigit(ch=getchar())); 9 register int x=ch^'0'; 10 while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0'); 11 return x; 12 } 13 const int N=201,L=1e6+1,S=26; 14 int cnt[N],p[N]; 15 char s[L]; 16 class AhoCorasick { 17 private: 18 std::queue<int> q; 19 int ch[L][S],fail[L],sum[L]; 20 std::list<int> list[L],e[L]; 21 int sz,new_node() { 22 return ++sz; 23 } 24 int idx(const char &c) const { 25 return c-'a'; 26 } 27 void dfs(const int &x) { 28 for(std::list<int>::const_iterator i=e[x].begin();i!=e[x].end();i++) { 29 const int &y=*i; 30 dfs(y); 31 sum[x]+=sum[y]; 32 } 33 } 34 public: 35 void insert(const char s[],const int &id) { 36 int p=0; 37 for(register int i=0;s[i];i++) { 38 const int c=idx(s[i]); 39 p=ch[p][c]?:ch[p][c]=new_node(); 40 } 41 list[p].push_front(id); 42 } 43 void get_fail() { 44 for(register int c=0;c<S;c++) { 45 if(ch[0][c]) { 46 q.push(ch[0][c]); 47 e[0].push_front(ch[0][c]); 48 } 49 } 50 while(!q.empty()) { 51 const int &x=q.front(); 52 for(register int c=0;c<S;c++) { 53 int &y=ch[x][c]; 54 if(!y) { 55 y=ch[fail[x]][c]; 56 continue; 57 } 58 fail[y]=ch[fail[x]][c]; 59 e[ch[fail[x]][c]].push_front(y); 60 q.push(y); 61 } 62 q.pop(); 63 } 64 } 65 void find(const char s[],const int &n) { 66 for(register int i=0,p=0;i<n;i++) { 67 sum[p=ch[p][idx(s[i])]]++; 68 } 69 } 70 void stat() { 71 dfs(0); 72 for(register int i=1;i<=sz;i++) { 73 for(register std::list<int>::const_iterator j=list[i].begin();j!=list[i].end();j++) { 74 cnt[*j]+=sum[i]; 75 } 76 } 77 } 78 }; 79 AhoCorasick ac; 80 int main() { 81 const int n=getint(); 82 for(register int i=0;i<n;i++) { 83 p[i]=strlen(s); 84 scanf("%s",&s[p[i]]); 85 ac.insert(&s[p[i]],i); 86 } 87 p[n]=strlen(s); 88 ac.get_fail(); 89 for(register int i=0;i<n;i++) { 90 ac.find(&s[p[i]],p[i+1]-p[i]); 91 } 92 ac.stat(); 93 for(register int i=0;i<n;i++) { 94 printf("%d\n",cnt[i]); 95 } 96 return 0; 97 }