KMP学习笔记(再回首)+ AC自动机学习笔记
一.KMP
引入
我们经常遇到字符串匹配问题。比如求一个长为
KMP
那么,有没有一种可以快速匹配的算法呢?这里我们介绍KMP。首先,我们来明确一些概念。我们把前面的
我们发现,在前文的暴力算法中,我们很可能匹配上主串的很长一部分然后失配,这样的话就浪费了一段匹配,因为这个匹配中很可能还存在可以作为模式串匹配起点的子串。我们又发现,这个可以作为起点的串满足它既是
特别地,对于每一个
至于
处理
void KMP(){ int j = 0; for(int i = 2; i<=n; i++){ while(j && a[j+1]!=a[i]){ j = nxt[j]; } if(a[j+1]==a[i]){ j++; } nxt[i] = j; } }
至于匹配,和处理
void work(){ int j = 0; for(int i = 1; i<=m; i++){ while(j && a[j+1]!=b[i]){ j = nxt[j]; } if(a[j+1] == b[i]){ j++; } if(j == n){ printf("%d\n", i-n+1); } } for(int i = 1; i<=n; i++){ printf("%d ", nxt[i]); } }
例题:
洛谷P3375(模板题)
我才不说就是把上面两个代码整一块儿就行了
#include<bits/stdc++.h> using namespace std; const int N = 1e6+100; int n, m; int nxt[N]; char a[N]; char b[N]; void KMP(){ int j = 0; for(int i = 2; i<=n; i++){ while(j && a[j+1]!=a[i]){ j = nxt[j]; } if(a[j+1]==a[i]){ j++; } nxt[i] = j; } } void work(){ int j = 0; for(int i = 1; i<=m; i++){ while(j && a[j+1]!=b[i]){ j = nxt[j]; } if(a[j+1] == b[i]){ j++; } if(j == n){ printf("%d\n", i-n+1); } } for(int i = 1; i<=n; i++){ printf("%d ", nxt[i]); } } int main(){ scanf("%s%s", b+1, a+1); n = strlen(a+1); m = strlen(b+1); KMP(); work(); return 0; }
洛谷P2375 动物园
这个题要求找出长度不超过当前串一半的
#include<bits/stdc++.h> using namespace std; const int N = 1e6+1000, mod = 1e9+7; int q, n; char s[N]; int nxt[N]; long long cnt[N]; long long ans = 1; long long KMP() { ans = 1; int pos = 0; nxt[1] = 0; cnt[1] = 1; for(int i = 2; i<=n; i++) { while(pos&&s[i]!=s[pos+1]) pos = nxt[pos]; if(s[i]==s[pos+1]) pos++; nxt[i] = pos; cnt[i] = cnt[pos]+1;//因为每一个border也可能有小的border,故数量是可以累加的。 }//第一遍 pos = 0; for(int i = 2; i<=n; i++) { while(pos&&s[i]!=s[pos+1]) pos = nxt[pos]; if(s[i]==s[pos+1]) pos++; while(pos&&pos*2>i) pos = nxt[pos]; ans=ans*(cnt[pos]+1)%mod;//cnt表示在pos点上的border总数。 } return ans; } int main() { scanf("%d", &q); while(q--) { scanf("%s", s+1); n = strlen(s+1); long long res = KMP(); printf("%lld\n", res); } return 0; }
CF149E Martian Strings
题意:给定一个主串
考虑如何拼接
#include<bits/stdc++.h> using namespace std; const int N = 1e5+100; char b[N], s[N]; int n, m, lth; int fnxt[N], bnxt[N], posl[1005], posr[1005]; bool KMP(){ int j = 0; for(int i = 2; i<=lth; i++){ while(j && b[j+1]!=b[i]){ j = fnxt[j]; } if(b[j+1]==b[i]){ j++; } fnxt[i] = j; } j = lth+1;bnxt[lth] = lth+1; for(int i = lth-1; i>=1; i--){ while(j<=lth&&b[j-1]!=b[i]){ j = bnxt[j]; } if(b[j-1]==b[i]){//只记录第一次出现的结束位置 j--; } bnxt[i] = j; } memset(posl, 0, sizeof(posl)); j = 0; for(int i = 1; i<=n; i++){ while(j && s[i]!=b[j+1]){ j = fnxt[j]; } if(s[i]==b[j+1]){ j++; } if(!posl[j]&&j){ posl[j] = i; } } j = lth+1; for(int i = n; i>=1; i--){ while(j<=lth&&b[j-1]!=s[i]){ j = bnxt[j]; } if(s[i]==b[j-1]){ j--; } if((j-1)&&posl[j-1]&&posl[j-1]<i&&(j<=lth)){//注意前后缀都要非空 return true; } } return false; } int ans; int main(){ scanf("%s", s+1); n = strlen(s+1); scanf("%d", &m); while(m--){ scanf("%s", b+1); lth = strlen(b+1); if(KMP()){ ans++; } } printf("%d\n", ans); return 0; }
二.AC自动机
引入/前置知识
我们通过KMP可以做到主串匹配单模式串。那么,当模式串多起来后,有什么解决方案呢?
答案是KMP上树。没错,AC自动机就是Trie树和KMP的结合。所以在学习之前,需要先学习KMP和Trie树。
AC自动机
首先我们引入一个概念:失配指针
那么怎么去构建
和KMP类似,就是在Trie树上不断跳
我们可以采用类似并查集路径压缩的方式,将所有节点的空儿子,如
Trie树的构建:
struct Trie{ int cnt; int son[26]; }tr[N]; int idx; void insert(char s[]){ int lth = strlen(s); int u = 0, v; for(int i = 0; i<lth; i++){ v = s[i]-'a'; if(!tr[u].son[v]){ tr[u].son[v] = ++idx; } u = tr[u].son[v]; } tr[u].cnt++; }
指针的构建+Trie图的构建(bfs):
int fail[N]; queue<int> q; void build(){ for(int i = 0; i<26; i++){ if(tr[0].son[i]) q.push(tr[0].son[i]); } while(q.size()){ int u = q.front(); q.pop(); for(int i = 0; i<26; i++){ if(tr[u].son[i]){ fail[tr[u].son[i]] = tr[fail[u]].son[i], q.push(tr[u].son[i]); } else{ tr[u].son[i] = tr[fail[u]].son[i]; } } } }
至于查询,每新增一个字符,都要把整个Trie跳一遍。这里要统计出现的模式串数量,故每次需要清空。
查询:
int query(char t[]){ int u = 0, ret = 0,v, lth = strlen(t); for(int i = 0; i<lth; i++){ int v = t[i]-'a'; u = tr[u].son[v]; for(int j = u; j && tr[j].cnt!=-1; j = fail[j]){//遍历过的模式串没必要再跳 ret+=tr[j].cnt; tr[j].cnt = -1; } } return ret; }
例题
洛谷P3808 模板1
将以上模板套用即可。
#include<bits/stdc++.h> using namespace std; const int N = 8e6+100; struct Trie{ int cnt; int son[26]; }tr[N]; int idx; void insert(char s[]){ int lth = strlen(s); int u = 0, v; for(int i = 0; i<lth; i++){ v = s[i]-'a'; if(!tr[u].son[v]){ tr[u].son[v] = ++idx; } u = tr[u].son[v]; } tr[u].cnt++; } int fail[N]; queue<int> q; void build(){ for(int i = 0; i<26; i++){ if(tr[0].son[i]) q.push(tr[0].son[i]); } while(q.size()){ int u = q.front(); q.pop(); for(int i = 0; i<26; i++){ if(tr[u].son[i]){ fail[tr[u].son[i]] = tr[fail[u]].son[i], q.push(tr[u].son[i]); } else{ tr[u].son[i] = tr[fail[u]].son[i]; } } } } int query(char t[]){ int u = 0, ret = 0, v, lth = strlen(t); for(int i = 0; i<lth; i++){ v = t[i]-'a'; u = tr[u].son[v]; for(int j = u; j && tr[j].cnt!=-1; j = fail[j]){ ret+=tr[j].cnt; tr[j].cnt = -1; } } return ret; } int n, lth; char s[1000010]; int main(){ scanf("%d", &n); for(int i = 1; i<=n; i++){ scanf("%s", s); insert(s); } build(); scanf("%s", s); printf("%d\n", query(s)); return 0; }
洛谷P3796 模板2(加强版)
这次是让你统计出现次数了。发现模式串很少,又发现没有相同的模式串(我一开始还傻傻地开了vector),直接开个桶记录一下次数,最后暴力扫即可。
#include<bits/stdc++.h> using namespace std; const int N = 12000; int n; struct Trie{ int son[26]; int have; }tr[N]; int idx; int fail[N], cnt[160]; void init(){ idx = 0; memset(fail, 0, sizeof(fail)); memset(tr, 0, sizeof(tr)); memset(cnt, 0, sizeof(cnt)); } void insert(char s[], int id){ int lth = strlen(s); int u = 0, v; for(int i = 0; i<lth; i++){ v = s[i]-'a'; if(!tr[u].son[v]){ tr[u].son[v] = ++idx; } u = tr[u].son[v]; } tr[u].have = id; } queue<int> q; void build(){ for(int i = 0; i<26; i++){ if(tr[0].son[i]){ q.push(tr[0].son[i]); } } int u; while(!q.empty()){ u = q.front(); q.pop(); for(int i = 0; i<26; i++){ if(tr[u].son[i]){ fail[tr[u].son[i]] = tr[fail[u]].son[i], q.push(tr[u].son[i]); } else{ tr[u].son[i] = tr[fail[u]].son[i]; } } } } void query(char t[]){ int lth = strlen(t), u = 0, v; for(int i = 0; i<lth; i++){ v = t[i]-'a'; u = tr[u].son[v]; for(int j = u; j; j = fail[j]){ if(tr[j].have){ cnt[tr[j].have]++; } } } } char tmp[160][80], t[1000050]; int mx; int main(){ scanf("%d", &n); while(n){ init(); for(int i = 1; i<=n; i++){ scanf("%s", tmp[i]); insert(tmp[i], i); } build(); scanf("%s", t); query(t); mx = 0; for(int i = 1; i<=n; i++){ mx = max(cnt[i], mx); } printf("%d\n", mx); for(int i = 1; i<=n; i++){ if(cnt[i] == mx){ printf("%s\n", tmp[i]); } } scanf("%d", &n); } return 0; }
洛谷P5357 模板3(二次加强)
乍眼一看,这题和上一道题不一样吗?虽然有重复串,但完全可以直接记录一个串,再让其他相同串指向这个串即可。于是乎——
76pts
代码还是放一下毕竟写半天不容易
#include<bits/stdc++.h> using namespace std; const int N = 2e6+100; int tr[N][26], e[N], idx, he[N]; void insert(char s[], int x){ int lth = strlen(s), u = 0, v; for(int i = 0; i<lth; i++){ int v = s[i]-'a'; if(!tr[u][v]){ tr[u][v] = ++idx; } u = tr[u][v]; } if(!e[u]){ e[u] = x; } else{ he[x] = e[u]; } } int fail[N]; queue<int> q; void build(){ for(int i = 0; i<26; i++){ if(tr[0][i]){ q.push(tr[0][i]); } } int u; while(!q.empty()){ u = q.front(); q.pop(); for(int i = 0; i<26; i++){ if(tr[u][i]){ fail[tr[u][i]] = tr[fail[u]][i]; q.push(tr[u][i]); } else{ tr[u][i] = tr[fail[u]][i]; } } } } int cnt[N]; void query(char t[]){ int lth = strlen(t); int u = 0, v; for(int i = 0; i<lth; i++){ v = t[i]-'a'; u = tr[u][v]; for(int j = u; j; j = fail[j]){ if(e[j]){ cnt[e[j]]++; } } } } int n; char s[N]; int main(){ scanf("%d", &n); for(int i = 1; i<=n; i++){ scanf("%s", s); insert(s, i); } scanf("%s", s); build(); query(s); for(int i = 1; i<=n; i++){ if(he[i]){ printf("%d\n", cnt[he[i]]); } else{ printf("%d\n", cnt[i]); } } return 0; }//只有76pts qwq
让我们来分析一下为什么:因为对于每个点我们都要完整地跳一遍
那可不可以让每个点只经过一次呢?答案是可以的(为什么我一开始想到了给否了qwq)。
答案是拓扑排序。
其实一开始我就在想,既然有的模板串是包含在另一些模板串中,那我们是不是只需要标记一个模板串,然后向上回溯,做一个树形dp就行。然鹅我觉得不好实现,因为
很明显,
这样的话,我们每次只需要在一个节点上修改权值。因为这个节点的权值会贡献给它所有的子串,而它所对应的字符串的所有的子串一定是这个串某一部分的后缀,所以一定是能通过跳
代码:
#include<bits/stdc++.h> using namespace std; const int N = 2e6+100; int tr[N][26], e[N], idx, he[N]; int inde[N]; bool vis[N]; void insert(char s[], int x){ int lth = strlen(s), u = 0, v; for(int i = 0; i<lth; i++){ int v = s[i]-'a'; if(!tr[u][v]){ tr[u][v] = ++idx; } u = tr[u][v]; } e[x] = u; } int fail[N]; queue<int> q; void build(){ for(int i = 0; i<26; i++){ if(tr[0][i]){ q.push(tr[0][i]); } } int u; while(!q.empty()){ u = q.front(); q.pop(); for(int i = 0; i<26; i++){ if(tr[u][i]){ fail[tr[u][i]] = tr[fail[u]][i]; inde[fail[tr[u][i]]]++; q.push(tr[u][i]); } else{ tr[u][i] = tr[fail[u]][i]; } } } } int cnt[N]; void query(char t[]){ int lth = strlen(t); int u = 0, v; for(int i = 0; i<lth; i++){ v = t[i]-'a'; u = tr[u][v]; cnt[u]++; } for(int i = 1; i<=idx; i++){ if(!inde[i]){ q.push(i); } } while(!q.empty()){ int u = q.front(); q.pop(); int v = fail[u]; inde[v]--; cnt[v]+=cnt[u]; if(!inde[v]){ q.push(fail[u]); } } } int n; char s[N]; int main(){ scanf("%d", &n); for(int i = 1; i<=n; i++){ scanf("%s", s); insert(s, i); } scanf("%s", s); build(); query(s); for(int i = 1; i<=n; i++){ printf("%d\n", cnt[e[i]]); } return 0; }