【学习笔记】AC自动机

AC自动机是以trie结构为基础,结合KMP算法思想构建的,用于解决多模式串匹配问题。
它的构建方式分为以下几步:
\(1.\) 建立trie树
\(2.\) 构建失配(fail)指针
其中 fail 指针指向的是当前节点的状态的后缀所对应的状态。

这里明确一下,trie树中的每个节点表示的是一个状态,及某个模式串的前缀。
因此若这个状态可以在文本串中匹配,则它的后缀必定也能在文本串中匹配。

构建 fail 指针的过程是一个 bfs,对于每个点,用其父亲的 fail 来更新自己的 fail
即,父亲的后缀加上自己当前的这一位字符,就是自己的后缀了。
核心代码:

复制代码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
for(int i=0;i<=25;i++){ if(tr[0][i]){ q.push(tr[0][i]); } } while(q.size()){ int u=q.front(); q.pop(); for(int i=0;i<=25;i++){ if(tr[u][i]){ fail[tr[u][i]]=tr[fail[u]][i]; q.push(tr[u][i]); } else{ // 重构trie树结构 tr[u][i]=tr[fail[u]][i]; } } }

代码中有一点是没有提到的,就是加了注释的那一句。即,若没有 u+(i+'a') 这个状态,就让这个状态指向这个状态的后缀。于是在前面那句 if 和以后的匹配中,可以避免 while 循环不停跳 fail 的情况。

板子题:HDU 2222
匹配时不断跳 fail 就行了,即若这个状态能匹配,则它的后缀也能匹配。注意打标记避免重复匹配,否则时间复杂度无法保证。

Code
复制代码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
#include<bits/stdc++.h> #define ll long long #define il inline #define read(x){\ char ch;\ int fu=1;\ while(!isdigit(ch=getchar()))\ fu-=(ch=='-')<<1;\ x=ch&15;\ while(isdigit(ch=getchar()))\ x=(x<<1)+(x<<3)+(ch&15);\ x*=fu;\ } using namespace std; namespace asbt{ namespace cplx{bool begin;} const int maxn=1e4+5,maxm=1e6+5; int T,n,tot,tr[maxn*50][30]; int fail[maxn*50],end[maxn*50]; char s[maxm]; queue<int> q; namespace cplx{ bool end; il double usdmem(){return (&begin-&end)/1048576.0;} } int main(){ read(T); while(T--){ tot=0; fail[0]=end[0]=0; for(int i=0;i<=25;i++){ tr[0][i]=0; } read(n); while(n--){ scanf(" %s",s+1); int len=strlen(s+1); int p=0; for(int i=1;i<=len;i++){ int d=s[i]-'a'; if(!tr[p][d]){ tr[p][d]=++tot; fail[tot]=end[tot]=0; for(int j=0;j<=25;j++){ tr[tot][j]=0; } } p=tr[p][d]; } end[p]++; } for(int i=0;i<=25;i++){ if(tr[0][i]){ q.push(tr[0][i]); } } while(q.size()){ int u=q.front(); q.pop(); for(int i=0;i<=25;i++){ if(tr[u][i]){ fail[tr[u][i]]=tr[fail[u]][i]; q.push(tr[u][i]); } else{ tr[u][i]=tr[fail[u]][i]; } } } scanf(" %s",s+1); n=strlen(s+1); int res=0,p=0; for(int i=1;i<=n;i++){ p=tr[p][s[i]-'a']; for(int j=p;j&&~end[j];j=fail[j]){ res+=end[j]; end[j]=-1; } } printf("%d\n",res); } return 0; } } int main(){return asbt::main();}

Luogu P5231
这里就要注意到,trie树上本来的从父亲连向儿子的边,和构建 fail 指针时为了方便重构的边的区别了。在代码中,后者用 \(tra\) 表示。
每个点都是它的子节点的前缀,于是可以将父亲节点的答案(如果有的话)传给儿子节点,最后在叶子节点统计答案即可。因为是动态开点,每个点的子节点编号肯定比它自己大,因此不用 dfs 遍历整棵树,直接顺序遍历所有节点就可以了。

Code
复制代码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
#include<bits/stdc++.h> #define ll long long #define il inline #define read(x){\ char ch;\ int fu=1;\ while(!isdigit(ch=getchar()))\ fu-=(ch=='-')<<1;\ x=ch&15;\ while(isdigit(ch=getchar()))\ x=(x<<1)+(x<<3)+(ch&15);\ x*=fu;\ } #define mp make_pair #define lwrb lower_bound #define uprb upper_bound using namespace std; namespace asbt{ namespace cplx{bool begin;} const int maxn=1e7+5,maxm=1e5+5; int n,m,tot,tr[maxn][4],fail[maxn]; int dep[maxn],zhi[maxn],ans[maxm]; int tra[maxn][4]; char s[maxn],t[105]; multimap<int,int> hao; queue<int> q; il int gid(char x){ if(x=='E'){ return 0; } if(x=='S'){ return 1; } if(x=='W'){ return 2; } return 3; } namespace cplx{ bool end; il double usdmem(){return (&begin-&end)/1048576.0;} } int main(){ read(n)read(m); scanf(" %s",s+1); for(int i=1;i<=m;i++){ scanf(" %s",t+1); int p=0,len=strlen(t+1); for(int j=1;j<=len;j++){ int d=gid(t[j]); if(!tr[p][d]){ tra[p][d]=tr[p][d]=++tot; } p=tr[p][d]; dep[p]=j; } hao.insert(mp(p,i)); } for(int i=0;i<=3;i++){ if(tr[0][i]){ q.push(tr[0][i]); } } while(q.size()){ int u=q.front(); q.pop(); for(int i=0;i<=3;i++){ if(tr[u][i]){ fail[tr[u][i]]=tra[fail[u]][i]; q.push(tr[u][i]); } else{ tra[u][i]=tra[fail[u]][i]; } } } int p=0; for(int i=1;i<=n;i++){ p=tra[p][gid(s[i])]; for(int j=p;j&&!zhi[j];j=fail[j]){ zhi[j]=dep[j]; } } for(int i=1;i<=tot;i++){ bool flag=0; for(int j=0;j<=3;j++){ if(tr[i][j]){ flag=1; zhi[tr[i][j]]=max(zhi[tr[i][j]],zhi[i]); } } if(!flag){ auto l=hao.lwrb(i); auto r=hao.uprb(i); while(l!=r){ ans[l++->second]=zhi[i]; } } } for(int i=1;i<=m;i++){ printf("%d\n",ans[i]); } return 0; } } int main(){return asbt::main();}

然后我突然发现了一个问题,如果有一个模式串是另一个模式串的前缀的话,那它的结尾就不是叶子节点,就会无法统计到答案。比如下面这组数据:

复制代码
  • 1
  • 2
  • 3
  • 4
4 2 NNSS NN NNS

刚才这篇代码会对第一个模式串输出 0。因此需要在每个节点记一个 end,然后把统计答案的判断条件改为 end[i]。(为什么不在每个节点都统计一遍答案呢,因为时间会达到 \(O(n\log m)\),有可能会炸。)然而洛谷上显然并没有这样的数据。
真正正确的代码:

Code
复制代码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
#include<bits/stdc++.h> #define ll long long #define il inline #define read(x){\ char ch;\ int fu=1;\ while(!isdigit(ch=getchar()))\ fu-=(ch=='-')<<1;\ x=ch&15;\ while(isdigit(ch=getchar()))\ x=(x<<1)+(x<<3)+(ch&15);\ x*=fu;\ } #define mp make_pair #define lwrb lower_bound #define uprb upper_bound using namespace std; namespace asbt{ namespace cplx{bool begin;} const int maxn=1e7+5,maxm=1e5+5; int n,m,tot,tr[maxn][4],fail[maxn]; int dep[maxn],zhi[maxn],ans[maxm]; int tra[maxn][4]; char s[maxn],t[105]; multimap<int,int> hao; queue<int> q; bitset<maxn> end; il int gid(char x){ if(x=='E'){ return 0; } if(x=='S'){ return 1; } if(x=='W'){ return 2; } return 3; } namespace cplx{ bool end; il double usdmem(){return (&begin-&end)/1048576.0;} } int main(){ read(n)read(m); scanf(" %s",s+1); for(int i=1;i<=m;i++){ scanf(" %s",t+1); int p=0,len=strlen(t+1); for(int j=1;j<=len;j++){ int d=gid(t[j]); if(!tr[p][d]){ tra[p][d]=tr[p][d]=++tot; } p=tr[p][d]; dep[p]=j; } end[p]=1; hao.insert(mp(p,i)); } for(int i=0;i<=3;i++){ if(tr[0][i]){ q.push(tr[0][i]); } } while(q.size()){ int u=q.front(); q.pop(); for(int i=0;i<=3;i++){ if(tr[u][i]){ fail[tr[u][i]]=tra[fail[u]][i]; q.push(tr[u][i]); } else{ tra[u][i]=tra[fail[u]][i]; } } } int p=0; for(int i=1;i<=n;i++){ p=tra[p][gid(s[i])]; for(int j=p;j&&!zhi[j];j=fail[j]){ zhi[j]=dep[j]; } } for(int i=1;i<=tot;i++){ for(int j=0;j<=3;j++){ if(tr[i][j]){ zhi[tr[i][j]]=max(zhi[tr[i][j]],zhi[i]); } } if(end[i]){ auto l=hao.lwrb(i); auto r=hao.uprb(i); while(l!=r){ ans[l++->second]=zhi[i]; } } } for(int i=1;i<=m;i++){ printf("%d\n",ans[i]); } return 0; } } int main(){return asbt::main();} /* 4 2 NNSS NN NNS */

一本通 1482
计算出现次数,那也不难。只需给匹配到的每个点的出现次数都加一,最后对于每个模式串对应到字典树上的编号查询即可。这样做的复杂度是 \(O(n|S|)\) 的,过不去。于是考虑只给 fail 链的链头加上贡献,再跑一遍拓扑排序即可。(因为 fail 指针连成的链不可能有环。)
另外,上面那道题用 multimap 存储字典树的节点对应的模式串编号的方式非常的蠢,直接开个数组存每个模式串对应的节点编号就行了。否则在一本通上会喜提RE+WA。

Code
复制代码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
#include<bits/stdc++.h> #define ll long long #define il inline using namespace std; namespace asbt{ namespace cplx{bool begin;} const int maxn=1e6+5; int n,len[205],fail[maxn],num[maxn]; int tot,tr[maxn][30],deg[maxn],hao[205]; string s[205]; queue<int> q; namespace cplx{ bool end; il double usdmem(){return (&begin-&end)/1048576.0;} } int main(){ ios::sync_with_stdio(0),cin.tie(0); cin>>n; for(int i=1;i<=n;i++){ cin>>s[i]; len[i]=s[i].size(); s[i]=" "+s[i]; int p=0; for(int j=1;j<=len[i];j++){ int d=s[i][j]-'a'; if(!tr[p][d]){ tr[p][d]=++tot; } p=tr[p][d]; } hao[i]=p; } for(int i=0;i<=25;i++){ if(tr[0][i]){ q.push(tr[0][i]); } } while(q.size()){ int u=q.front(); q.pop(); for(int i=0;i<=25;i++){ if(tr[u][i]){ fail[tr[u][i]]=tr[fail[u]][i]; if(fail[tr[u][i]]){ deg[fail[tr[u][i]]]++; } q.push(tr[u][i]); } else{ tr[u][i]=tr[fail[u]][i]; } } } for(int i=1;i<=n;i++){ int p=0; for(int j=1;j<=len[i];j++){ p=tr[p][s[i][j]-'a']; num[p]++; } } for(int i=1;i<=tot;i++){ if(!deg[i]){ q.push(i); } } while(q.size()){ int u=q.front(); q.pop(); int v=fail[u]; if(!v){ continue; } num[v]+=num[u]; if(--deg[v]==0){ q.push(v); } } for(int i=1;i<=n;i++){ cout<<num[hao[i]]<<"\n"; } return 0; } } int main(){return asbt::main();}

Luogu P2444
题目意思是要找到一个无法匹配任何模式串的无限长的文本串。
不难发现,这样的串在AC自动机上匹配时会跑出一个环。
因此只需要在AC自动机上找有没有不经过任何模式串的结尾节点的环就可以了。
但是这里会有问题:某个点或许不是模式串的结尾,但是它的后缀有可能是。
于是在求 fail 时,如果这个点的 fail 是结尾节点,那么就把它也设为结尾节点(dfs时不能经过它)。
代码不难写,但会TLE。
原因是由于要找环,在遍历完这个点的子树后就将这个点的访问状态(vis2 数组)改成0了,但在之后再遍历这个点是没有意义的。
因此就再记一个 vis1 数组,记录是否遍历过它就行了。

Code
复制代码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
#include<bits/stdc++.h> #define ll long long #define il inline #define read(x){\ char ch;\ int fu=1;\ while(!isdigit(ch=getchar()))\ fu-=(ch=='-')<<1;\ x=ch&15;\ while(isdigit(ch=getchar()))\ x=(x<<1)+(x<<3)+(ch&15);\ x*=fu;\ } using namespace std; namespace asbt{ namespace cplx{bool begin;} const int maxn=3e4+5; int n,tot,tr[maxn][2],fail[maxn]; char s[maxn]; bitset<maxn> end,vis1,vis2; queue<int> q; il void dfs(int u){ vis1[u]=vis2[u]=1; for(int i=0;i<=1;i++){ if(vis2[tr[u][i]]){ puts("TAK"); exit(0); } if(end[tr[u][i]]||vis1[tr[u][i]]){ continue; } dfs(tr[u][i]); } vis2[u]=0; } namespace cplx{ bool end; il double usdmem(){return (&begin-&end)/1048576.0;} } int main(){ read(n); while(n--){ scanf(" %s",s+1); int len=strlen(s+1),p=0; for(int i=1;i<=len;i++){ int d=s[i]&15; if(!tr[p][d]){ tr[p][d]=++tot; } p=tr[p][d]; } end[p]=1; } for(int i=0;i<=1;i++){ if(tr[0][i]){ q.push(tr[0][i]); } } while(q.size()){ int u=q.front(); q.pop(); for(int i=0;i<=1;i++){ if(tr[u][i]){ fail[tr[u][i]]=tr[fail[u]][i]; if(end[fail[tr[u][i]]]){ end[tr[u][i]]=1; } q.push(tr[u][i]); } else{ tr[u][i]=tr[fail[u]][i]; } } } // for(int i=1;i<=tot;i++){ // cout<<tr[i][0]<<" "<<tr[i][1]<<" "<<fail[i]<<"\n"; // } dfs(0); puts("NIE"); return 0; } } int main(){return asbt::main();}

[HNOI2006] 最短母串问题
要求字典序最小,考虑一位一位贪心,bfs。先建出 AC 自动机,然后 bfs 同时记录路径即可。需要状压。

Code
复制代码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
#include<bits/stdc++.h> #define ll long long #define il inline #define pii pair<int,int> #define mp make_pair #define fir first #define sec second using namespace std; namespace asbt{ namespace cplx{bool begin;} const int maxn=605,maxm=(1<<12)+5; int n,tr[maxn][30],tot; int fail[maxn],sta[maxn]; pair<pii,int> pre[maxn][maxm]; bool vis[maxn][maxm]; char ans[maxn]; string s; queue<int> q; queue<pii> q2; namespace cplx{ bool end; il double usdmem(){return (&begin-&end)/1048576.0;} } int main(){ // cout<<cplx::usdmem(); ios::sync_with_stdio(0),cin.tie(0); cin>>n; for(int i=1,p;i<=n;i++){ cin>>s; p=0; for(int j=0,d;j<s.size();j++){ d=s[j]-'A'; if(!tr[p][d]){ tr[p][d]=++tot; } p=tr[p][d]; } sta[p]|=1<<(i-1); } for(int i=0;i<=25;i++){ if(tr[0][i]){ q.push(tr[0][i]); } } while(q.size()){ int u=q.front(); q.pop(); for(int i=0;i<=25;i++){ if(tr[u][i]){ fail[tr[u][i]]=tr[fail[u]][i]; sta[tr[u][i]]|=sta[fail[tr[u][i]]]; q.push(tr[u][i]); } else{ tr[u][i]=tr[fail[u]][i]; } } } vis[0][0]=1,q2.push(mp(0,0)); // puts("666"); while(q2.size()){ int u=q2.front().fir,S=q2.front().sec; // puts("666"); // cout<<u<<"\n"; q2.pop(); // puts("777"); if(S==((1<<n)-1)){ int cnt=0; while(u){ // puts("666"); // cout<<u<<"\n"; ans[++cnt]='A'+pre[u][S].sec; pii tmp=pre[u][S].fir; u=tmp.fir,S=tmp.sec; } for(int i=cnt;i;i--){ cout<<ans[i]; } return 0; } for(int i=0;i<=25;i++){ if(tr[u][i]&&!vis[tr[u][i]][S|sta[tr[u][i]]]){ vis[tr[u][i]][S|sta[tr[u][i]]]=1; pre[tr[u][i]][S|sta[tr[u][i]]]=mp(mp(u,S),i); q2.push(mp(tr[u][i],S|sta[tr[u][i]])); } } } return 0; } } int main(){return asbt::main();}

「一本通 2.4 练习 6」文本生成器
正难则反,设 \(dp_{i,j}\) 表示到 \(i\) 点串长为 \(j\) 的方案数,从父亲向儿子转移即可。注意转移顺序,应该先枚举 \(i\)

Code
复制代码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
#include<bits/stdc++.h> #define ll long long #define il inline using namespace std; namespace asbt{ namespace cplx{bool begin;} const int maxn=6e3+5,mod=1e4+7; int n,m,tot,tr[maxn][30]; int fail[maxn],dp[maxn][105]; bool end[maxn]; queue<int> q; string s; il int qpow(int x,int y){ int res=1; while(y){ if(y&1){ res=res*x%mod; } y>>=1,x=x*x%mod; } return res; } namespace cplx{ bool end; il double usdmem(){return (&begin-&end)/1048576.0;} } int main(){ ios::sync_with_stdio(0),cin.tie(0); cin>>n>>m; for(int i=1,p;i<=n;i++){ cin>>s; p=0; for(int j=0,d;j<s.size();j++){ d=s[j]-'A'; if(!tr[p][d]){ tr[p][d]=++tot; } p=tr[p][d]; } end[p]=1; } for(int i=0;i<=25;i++){ if(tr[0][i]){ q.push(tr[0][i]); } } while(q.size()){ int u=q.front(); q.pop(); for(int i=0;i<=25;i++){ if(tr[u][i]){ fail[tr[u][i]]=tr[fail[u]][i]; if(end[fail[tr[u][i]]]){ end[tr[u][i]]=1; } q.push(tr[u][i]); } else{ tr[u][i]=tr[fail[u]][i]; } } } dp[0][0]=1; for(int i=1;i<=m;i++){ for(int j=0;j<=tot;j++){ for(int k=0;k<=25;k++){ if(!end[tr[j][k]]){ (dp[tr[j][k]][i]+=dp[j][i-1])%=mod; } } } } int ans=0; for(int i=0;i<=tot;i++){ (ans+=dp[i][m])%=mod; } cout<<(qpow(26,m)-ans+mod)%mod; return 0; } } int main(){return asbt::main();}

[BZOJ 2905]背单词
\(dp_i\) 表示前 \(i\) 个字符串的最大答案。则 \(i\) 只能从它的子串转移。考虑子串本质就是从前面删几个字符再从后面删几个字符串,具体到 AC 自动机上就是 \(i\) 的某个前缀跳了若干个 fail。于是我们从 \(fail_i\)\(i\) 连边,这样就建出了一棵 fail 树,对于每个前缀去查询根链[1]的最大值,然后再更新。用线段树维护即可。

Code
复制代码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
#include<bits/stdc++.h> #define ll long long #define il inline #define pb push_back #define lid id<<1 #define rid id<<1|1 using namespace std; namespace asbt{ namespace cplx{bool begin;} const int maxn=3e5+5,inf=0x3f3f3f3f; int T,n,tr[maxn][30]; int dfn[maxn],sz[maxn]; int fail[maxn],hao[maxn]; int fa[maxn],cnt,a[maxn]; vector<int> e[maxn]; string s; queue<int> q; il void dfs(int u){ // cout<<u<<"\n"; dfn[u]=++cnt; sz[u]=1; for(int v:e[u]){ dfs(v); sz[u]+=sz[v]; } } struct{ int tr[maxn<<2]; il void build(int id,int l,int r){ tr[id]=0; if(l==r){ return ; } int mid=(l+r)>>1; build(lid,l,mid); build(rid,mid+1,r); } il void upd(int id,int L,int R,int l,int r,int v){ if(L>=l&&R<=r){ tr[id]=max(tr[id],v); return ; } int mid=(L+R)>>1; if(l<=mid){ upd(lid,L,mid,l,r,v); } if(r>mid){ upd(rid,mid+1,R,l,r,v); } } il int query(int id,int l,int r,int p){ int res=tr[id]; if(l==r){ return res; } int mid=(l+r)>>1; return max(res,p<=mid?query(lid,l,mid,p):query(rid,mid+1,r,p)); } }SG; il void solve(){ cin>>n; int tot=0; for(int i=0;i<=25;i++){ tr[0][i]=0; } for(int i=1,p;i<=n;i++){ cin>>s>>a[i]; p=0; for(int j=0,d;j<s.size();j++){ d=s[j]-'a'; if(!tr[p][d]){ tr[p][d]=++tot; for(int k=0;k<=25;k++){ tr[tot][k]=0; } fa[tot]=p; } p=tr[p][d]; } hao[i]=p; } for(int i=0;i<=25;i++){ if(tr[0][i]){ q.push(tr[0][i]); } } while(q.size()){ int u=q.front(); // cout<<u<<"\n"; q.pop(); for(int i=0;i<=25;i++){ if(tr[u][i]){ fail[tr[u][i]]=tr[fail[u]][i]; q.push(tr[u][i]); } else{ tr[u][i]=tr[fail[u]][i]; } } } // cout<<tot<<"\n"; for(int i=1;i<=tot;i++){ e[fail[i]].pb(i); } cnt=0; dfs(0); // for(int i=0;i<=tot;i++){ // cout<<i<<" "<<dfn[i]<<" "<<sz[i]<<"\n"; // } SG.build(1,1,cnt); int ans=0; for(int i=1,res,p;i<=n;i++){ p=hao[i],res=0; // cout<<i<<"-\n"; while(p){ // cout<<p<<"\n"; res=max(res,SG.query(1,1,cnt,dfn[p])); p=fa[p]; } res=max(res,res+a[i]); ans=max(ans,res); SG.upd(1,1,cnt,dfn[hao[i]],dfn[hao[i]]+sz[hao[i]]-1,res); // cout<<res<<"\n"; } cout<<ans<<"\n"; for(int i=0;i<=tot;i++){ fa[i]=fail[i]=0; e[i].clear(); } } namespace cplx{ bool end; il double usdmem(){return (&begin-&end)/1048576.0;} } int main(){ // freopen("433.in","r",stdin); ios::sync_with_stdio(0),cin.tie(0); cin>>T; while(T--){ solve(); } return 0; } } int main(){return asbt::main();}

[JSOI2009]密码
\(dp_{i,j,S}\) 表示填了 \(i\) 位,在 AC 自动机上的 \(j\) 号节点,当前覆盖的字符串集位 \(S\) 的方案数。于是有转移:

\[\large{dp_{i,j,S}\to dp_{i+1,tr_{j,k},S\operatorname{or}sta_{tr_{j,k}}}} \]

其中 \(tr_{j,k}\) 表示 AC 自动机上 \(j\) 点加上字符 \(k\) 的节点,\(sta_j\) 表示以 \(j\) 点为结尾的字符串构成的集合,\(\operatorname{or}\) 表示按位或。
输出方案,先记忆化搜索确定每个状态 \((i,j,S)\) 能否转移到合法状态,再一遍 dfs 输出即可。

Code
复制代码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
#include<bits/stdc++.h> #define ll long long #define il inline using namespace std; namespace asbt{ namespace cplx{bool begin;} const int maxn=(1<<10)+5; int n,m,tr[105][30],tot; int fail[105],sta[105]; ll dp[30][105][maxn]; bool vis[30][105][maxn]; bool f[30][105][maxn]; char ans[maxn]; string s; queue<int> q; il bool dfs1(int i,int j,int S){ if(vis[i][j][S]){ return f[i][j][S]; } vis[i][j][S]=1; if(i==m){ return f[i][j][S]=S==(1<<n)-1; } bool &res=f[i][j][S]; for(int k=0;k<=25;k++){ res|=dfs1(i+1,tr[j][k],S|sta[tr[j][k]]); } return res; } il void dfs2(int i,int j,int S){ if(i==m){ for(int k=1;k<=m;k++){ cout<<ans[k]; } cout<<"\n"; return ; } for(int k=0;k<=25;k++){ if(f[i+1][tr[j][k]][S|sta[tr[j][k]]]){ ans[i+1]=k+'a'; dfs2(i+1,tr[j][k],S|sta[tr[j][k]]); } } } namespace cplx{ bool end; il double usdmem(){return (&begin-&end)/1048576.0;} } int main(){ ios::sync_with_stdio(0),cin.tie(0); cin>>m>>n; for(int i=1,p;i<=n;i++){ cin>>s; p=0; for(int j=0,d;j<s.size();j++){ d=s[j]-'a'; if(!tr[p][d]){ tr[p][d]=++tot; } p=tr[p][d]; } sta[p]|=1<<(i-1); } for(int i=0;i<=25;i++){ if(tr[0][i]){ q.push(tr[0][i]); } } while(q.size()){ int u=q.front(); q.pop(); for(int i=0;i<=25;i++){ if(tr[u][i]){ fail[tr[u][i]]=tr[fail[u]][i]; sta[tr[u][i]]|=sta[fail[tr[u][i]]]; q.push(tr[u][i]); } else{ tr[u][i]=tr[fail[u]][i]; } } } dp[0][0][0]=1; for(int i=0;i<=m;i++){ for(int j=0;j<=tot;j++){ for(int S=0;S<1<<n;S++){ if(!dp[i][j][S]){ continue; } for(int k=0;k<=25;k++){ dp[i+1][tr[j][k]][S|sta[tr[j][k]]]+=dp[i][j][S]; } } } } ll ans=0; for(int i=0;i<=tot;i++){ ans+=dp[m][i][(1<<n)-1]; } cout<<ans<<"\n"; if(ans>42){ return 0; } dfs1(0,0,0); dfs2(0,0,0); return 0; } } int main(){return asbt::main();}

[bzoj2553]禁忌
考虑如果暴力 DP 的话,时间复杂度会超标,于是矩阵加速。
但是在 DP 的过程中,期望又要乘又要加的,用矩阵很难转移。于是考虑用矩阵去算概率,新开一行去存期望。
这个贪心是很显然的:当匹配完了一个单词时,直接从头开始尝试匹配下一个单词。由于题目要求不能重叠,这不仅是策略上的优化还是正确性的保证。
具体地,假设要从节点 \(j\) 转移到 \(k\),如果 \(k\) 是某个单词的结尾,那就把贡献加给根,同时加给期望。否则就只能加给 \(k\) 了。
时间复杂度 \(O((\sum|T_i|)^3\log len)\)

Code
复制代码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
#include<bits/stdc++.h> #define ll long long #define il inline using namespace std; namespace asbt{ namespace cplx{bool begin;} int n,m,tr[80][30]; int tot,fail[80]; double ab; bool end[80]; string s; queue<int> q; struct juz{ double mat[80][80]; juz(){ for(int i=0;i<=tot;i++){ for(int j=0;j<=tot;j++){ mat[i][j]=0; } } } il double*operator[](int x){ return mat[x]; } il juz operator*(juz x)const{ juz res; for(int i=0;i<=tot;i++){ for(int j=0;j<=tot;j++){ for(int k=0;k<=tot;k++){ res[i][j]+=mat[i][k]*x[k][j]; } } } return res; } }bas; il juz qpow(juz x,int y){ juz res; for(int i=0;i<=tot;i++){ res[i][i]=1; } while(y){ if(y&1){ res=res*x; } y>>=1,x=x*x; } return res; } namespace cplx{ bool end; il double usdmem(){return (&begin-&end)/1048576.0;} } int main(){ ios::sync_with_stdio(0),cin.tie(0); cin>>n>>m>>ab; for(int i=1,p;i<=n;i++){ cin>>s; p=0; for(int j=0,d;j<s.size();j++){ d=s[j]-'a'; if(!tr[p][d]){ tr[p][d]=++tot; } p=tr[p][d]; } end[p]=1; } for(int i=0;i<ab;i++){ if(tr[0][i]){ q.push(tr[0][i]); } } while(q.size()){ int u=q.front(); q.pop(); for(int i=0;i<ab;i++){ if(tr[u][i]){ fail[tr[u][i]]=tr[fail[u]][i]; end[tr[u][i]]|=end[fail[tr[u][i]]]; q.push(tr[u][i]); } else{ tr[u][i]=tr[fail[u]][i]; } } } // for(int i=0;i<=tot;i++){ // cout<<fail[i]<<" "; // } // puts(""); tot++; for(int i=0;i<tot;i++){ for(int j=0;j<ab;j++){ if(end[tr[i][j]]){ bas[i][0]+=1.0/ab; bas[i][tot]+=1.0/ab; } else{ bas[i][tr[i][j]]+=1.0/ab; } } } bas[tot][tot]=1; printf("%.10f",qpow(bas,m)[0][tot]); return 0; } } int main(){return asbt::main();}

  1. 根链:由 2024 陕西省队队员马思博提出,指一棵树上一个端点在树根的链。(摘自 UKE 的题解) ↩︎

posted @   zhangxy__hp  阅读(42)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
展开