kuangbin专题十七:AC自动机
思路:AC自动机模板题。统计给出的关键词中有多少个在文本中出现过:同一个关键词在文本中出现多次只算一次,但重复给出的相同关键词要分别计数。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<queue>
using namespace std;

// Aho-Corasick automaton: for each test case, count how many of the given
// keywords occur in the text. A keyword given k times contributes k; each
// distinct keyword occurrence set is counted at most once.
const int maxn = 1e6+5;   // bound on trie nodes and on the text buffer size
int tr[maxn][26], tot;    // tr: trie edges, turned into the goto function by build(); node 0 is the root
int e[maxn], fail[maxn];  // e[u]: #keywords ending at u (set to -1 once counted); fail: suffix links
char s[maxn];             // 1-indexed input buffer (s[0] unused)

// Allocate the next trie node and zero its state. Zeroing lazily here (instead
// of memset-ing whole arrays per test case) keeps the reset cost proportional
// to the number of nodes actually used.
int newnode(){
    ++tot;
    memset(tr[tot], 0, sizeof(tr[tot]));
    e[tot] = fail[tot] = 0;
    return tot;
}

// Insert the 1-indexed lowercase keyword s[1..] into the trie.
void insert(char *s){
    int u = 0;
    for(int i = 1; s[i]; i++){
        int c = s[i]-'a';
        if(!tr[u][c]) tr[u][c] = newnode();
        u = tr[u][c];
    }
    e[u]++;   // duplicates of the same keyword are counted separately
}

queue<int> q;

// BFS over the trie. It computes fail links and fills missing edges with the
// fail target's edge, so that tr becomes a complete transition function.
void build(){
    for(int i = 0; i < 26; i++) if(tr[0][i]) q.push(tr[0][i]);   // depth-1 nodes keep fail = 0
    while(q.size()){
        int u = q.front(); q.pop();
        for(int i = 0; i < 26; i++)
            if(tr[u][i]) fail[tr[u][i]] = tr[fail[u]][i], q.push(tr[u][i]);
            else tr[u][i] = tr[fail[u]][i];
    }
}

// Run the automaton over the 1-indexed text t[1..] and return the number of
// keywords found. Visited nodes are marked e = -1. Any fail chain that reaches
// a marked node has already been fully processed, so each node's chain is
// walked at most once overall.
int query(char *t){
    int u = 0, res = 0;
    for(int i = 1; t[i]; i++){
        // No keyword contains a non-lowercase byte, so no match can span one:
        // restart from the root rather than indexing tr out of bounds.
        if(t[i] < 'a' || t[i] > 'z'){ u = 0; continue; }
        u = tr[u][t[i]-'a'];
        for(int j = u; j && e[j] != -1; j = fail[j]) res += e[j], e[j] = -1;
    }
    return res;
}

// Reset the automaton for a new test case. Only the root is cleared; every
// other node is cleared by newnode() when it is reused.
void init(){
    tot = 0;
    memset(tr[0], 0, sizeof(tr[0]));
    e[0] = fail[0] = 0;
}

int main(){
    int T;
    scanf("%d", &T);
    while(T--){
        int n;
        init();
        scanf("%d", &n);
        for(int i = 0; i < n; i++) scanf("%s", s+1), insert(s);
        scanf("%s", s+1);
        build();
        printf("%d\n", query(s));
    }
    return 0;
}
思路:注意字符的范围是ASCII码0-127(共128种),所以字典树每个结点要开128个分支,不能只开26个。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;

// Aho-Corasick automaton: n virus signatures (numbered 1..n), then m websites.
// For every website containing at least one signature, print
// "web <id>: <virus ids ascending>". Finally print "total: <#such websites>".
const int maxn = 500 * 200 + 5;  // trie node bound used by this solution
const int SIGMA = 128;           // alphabet: byte values 0..127
int tr[maxn][130], e[maxn], fail[maxn];  // e[u]: virus id ending at u (0 = none)
int tot;                         // number of allocated trie nodes (root is 0)
int stamp[maxn];                 // stamp[u] == query_id: u's fail chain already walked in this query
int query_id;
bool vis[505];                   // vis[id]: virus id found in the current website
char s[10005];                   // 1-indexed input buffer (s[0] unused)

// Insert the 1-indexed signature s[1..] with id idx.
// Assumes every byte of s is in [0, SIGMA).
void insert(int idx, char *s){
    int u = 0;
    for(int i = 1; s[i]; i++){
        int c = s[i];
        if(!tr[u][c]) tr[u][c] = ++tot;
        u = tr[u][c];
    }
    e[u] = idx;
}

queue<int> q;

// BFS over the trie. It computes fail links and fills missing edges so that
// tr becomes a complete transition function over the SIGMA symbols.
void build(){
    int u = 0;
    for(int i = 0; i < SIGMA; i++) if(tr[u][i]) q.push(tr[u][i]);
    while(q.size()){
        u = q.front(); q.pop();
        for(int i = 0; i < SIGMA; i++){
            if(tr[u][i]) fail[tr[u][i]] = tr[fail[u]][i], q.push(tr[u][i]);
            else tr[u][i] = tr[fail[u]][i];
        }
    }
}

// Scan the 1-indexed text t[1..] and set vis[id] for every signature found.
// Each node is stamped when its fail chain is walked. A stamped node's whole
// chain was already processed in this query, so the walk stops there. This
// bounds the work per query by |t| + #nodes. The old version stopped only at
// already-reported terminal nodes and re-walked long runs of non-terminal nodes.
void query(char *t){
    memset(vis, 0, sizeof(vis));
    ++query_id;
    int u = 0;
    for(int i = 1; t[i]; i++){
        unsigned char c = t[i];
        // Bytes outside the alphabet cannot occur in any signature, so no match
        // spans them: restart from the root instead of indexing tr out of range.
        if(c >= SIGMA){ u = 0; continue; }
        u = tr[u][c];
        for(int j = u; j && stamp[j] != query_id; j = fail[j]){
            stamp[j] = query_id;
            if(e[j]) vis[e[j]] = 1;
        }
    }
}

// Reset the whole automaton (not needed for this single-input program).
void init(){
    memset(tr, 0, sizeof(tr));
    memset(e, 0, sizeof(e));
    memset(fail, 0, sizeof(fail));
    memset(stamp, 0, sizeof(stamp));
    tot = 0;
    query_id = 0;
}

int main(){
    int n, m, res = 0;
    scanf("%d", &n);
    for(int id = 1; id <= n; id++) scanf("%s", s + 1), insert(id, s);
    build();
    scanf("%d", &m);
    for(int web = 1; web <= m; web++){
        scanf("%s", s + 1);
        query(s);
        // Virus ids stored in the trie are 1..n, so only those slots can be set.
        bool infected = false;
        for(int j = 1; j <= n; j++) if(vis[j]){ infected = true; break; }
        if(infected){
            printf("web %d:", web);
            for(int j = 1; j <= n; j++) if(vis[j]) printf(" %d", j);
            printf("\n");
            res++;
        }
    }
    printf("total: %d\n", res);
    return 0;
}
思路:同上。