POJ 1625 Censored!(AC自动机 + DP + 大数 + 拓展ASCII处理)题解
题意:给出一个由n个字符组成的字符集和p个病毒串,求由该字符集中的字符组成、长度为m、且不包含任何病毒串的字符串个数。
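思路:题目不取模,最坏情况下答案可达$50^{50}$,所以要用高精度板子。在AC自动机上做DP:设dp[i][j]为长度为i、在自动机上走到状态j、且不包含病毒串的字符串个数;每次枚举下一个字符进行转移,跳过会走到病毒状态(病毒串结尾,或者沿fail链能到达病毒串结尾的状态)的转移。因为m比较小,按长度直接递推即可。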
因为输入的字符中可能包含扩展ASCII码(128~255),而普通char是有符号还是无符号由编译器(实现)决定,其取值范围可能是-128~127,也可能是0~255。所以读入一个扩展ASCII字符后得到的值可能是负数,直接当数组下标会越界,处理时要先加一个偏移量(下面第一份代码加的是130)。也可以直接用map做映射(第二份代码),或者先转成unsigned char再使用。
代码(第一份用数组加偏移量映射字符,第二份用map映射字符):
#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include <iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;

const int maxn = 100 + 5;   // capacity of the trie (number of automaton states)
const int M = 50 + 5;       // rows of dp; must be at least m + 1
const int SIGMA = 50 + 1;   // children per trie node; must exceed the alphabet size n

int n, m, p;
int tol;                    // number of characters in the current alphabet
// reflect[c + 130] = index of character c in the alphabet. The +130 offset keeps the
// index non-negative whether plain char is signed (-128..127 -> 2..257) or
// unsigned (0..255 -> 130..385); both ranges fit in 600 entries.
int reflect[600];

// Non-negative big integer, little-endian, base 10000 (4 decimal digits per limb).
struct BigInt{
    const static int mod = 10000;
    int a[50], len;
    BigInt(){ memset(a, 0, sizeof(a)); len = 1; }
    // Assign a non-negative int value.
    void set(int v){
        memset(a, 0, sizeof(a));
        len = 0;
        do{
            a[len++] = v % mod;
            v /= mod;
        }while(v);
    }
    BigInt operator + (const BigInt &b) const{
        BigInt res;
        res.len = max(len, b.len);
        for(int i = 0; i <= res.len; i++) res.a[i] = 0;
        for(int i = 0; i < res.len; i++){
            res.a[i] += ((i < len)? a[i] : 0) + ((i < b.len)? b.a[i] : 0);
            res.a[i + 1] += res.a[i] / mod;   // carry into the next limb
            res.a[i] %= mod;
        }
        if(res.a[res.len] > 0) res.len++;
        return res;
    }
    BigInt operator * (const BigInt &b) const{
        BigInt res;
        for(int i = 0; i < len; i++){
            int up = 0;
            for(int j = 0; j < b.len; j++){
                int temp = a[i] * b.a[j] + res.a[i + j] + up;
                res.a[i + j] = temp % mod;
                up = temp / mod;
            }
            if(up != 0) res.a[i + b.len] = up;
        }
        res.len = len + b.len;
        while(res.a[res.len - 1] == 0 && res.len > 1) res.len--;
        return res;
    }
    // Print in decimal followed by a newline; inner limbs are zero-padded to 4 digits.
    void output(){
        printf("%d", a[len - 1]);
        for(int i = len - 2; i >= 0; i--){
            printf("%04d", a[i]);
        }
        printf("\n");
    }
};

// dp[i][j]: number of virus-free strings of length i that drive the automaton to state j.
BigInt dp[M][maxn];

// Aho-Corasick automaton over the virus strings, plus the counting DP on top of it.
struct Aho{
    struct state{
        int next[SIGMA];    // trie edges; after build() also the goto transitions
        int fail, cnt;      // fail link; cnt = 1 if a virus ends here or at a state on the fail chain
    }node[maxn];
    int size;
    queue<int> q;
    void init(){
        size = 0;
        newtrie();          // state 0 is the root
        while(!q.empty()) q.pop();
    }
    int newtrie(){
        memset(node[size].next, 0, sizeof(node[size].next));
        node[size].cnt = node[size].fail = 0;
        return size++;
    }
    // Insert one virus string; every character must belong to the current alphabet.
    void insert(char *s){
        int len = strlen(s);
        int now = 0;
        for(int i = 0; i < len; i++){
            int c = reflect[int(s[i]) + 130];
            if(node[now].next[c] == 0){
                node[now].next[c] = newtrie();
            }
            now = node[now].next[c];
        }
        node[now].cnt = 1;
    }
    // BFS that sets fail links, propagates the virus flag along them, and fills
    // missing edges with goto transitions.
    void build(){
        node[0].fail = -1;
        q.push(0);
        while(!q.empty()){
            int u = q.front();
            q.pop();
            // Test u first: the root's fail link is -1, so evaluating
            // node[node[0].fail] would read node[-1].
            if(u && node[node[u].fail].cnt) node[u].cnt = 1;
            for(int i = 0; i < SIGMA; i++){
                if(!node[u].next[i]){
                    // Missing edge: reuse the fail state's transition (the root keeps 0, itself).
                    if(u != 0) node[u].next[i] = node[node[u].fail].next[i];
                } else{
                    int child = node[u].next[i];
                    if(u == 0) node[child].fail = 0;
                    else{
                        int v = node[u].fail;
                        while(v != -1){
                            if(node[v].next[i]){
                                node[child].fail = node[v].next[i];
                                break;
                            }
                            v = node[v].fail;
                        }
                        if(v == -1) node[child].fail = 0;
                    }
                    q.push(child);
                }
            }
        }
    }
    // Count virus-free strings of length m over the alphabet and print the count.
    void query(){
        for(int i = 0; i <= m; i++){
            for(int j = 0; j < size; j++){
                dp[i][j].set(0);
            }
        }
        dp[0][0].set(1);    // the empty string sits at the root (also makes m == 0 yield 1)
        for(int i = 0; i < m; i++){
            for(int j = 0; j < size; j++){
                for(int k = 0; k < tol; k++){
                    int v = node[j].next[k];
                    if(node[v].cnt == 0){
                        dp[i + 1][v] = dp[i + 1][v] + dp[i][j];
                    }
                }
            }
        }
        BigInt ans;
        ans.set(0);
        for(int j = 0; j < size; j++){
            if(node[j].cnt == 0){
                ans = ans + dp[m][j];
            }
        }
        ans.output();
    }
}ac;

char s[100];
int main(){
    while(~scanf("%d%d%d", &n, &m, &p)){
        scanf("%99s", s);   // the alphabet: n characters read as one token
        tol = 0;
        for(int i = 0; i < n; i++){
            reflect[int(s[i]) + 130] = tol++;
        }
        ac.init();
        while(p--){
            scanf("%99s", s);
            ac.insert(s);
        }
        ac.build();
        ac.query();
    }
    return 0;
}
#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include <iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;

const int maxn = 100 + 5;   // capacity of the trie (number of automaton states)
const int M = 50 + 5;       // rows of dp; must be at least m + 1
const int SIGMA = 50 + 1;   // children per trie node; must exceed the alphabet size n

int n, m, p;
int tol;                    // number of characters in the current alphabet
// Character -> alphabet index. Keyed by char itself, so whether plain char is
// signed or unsigned does not matter as long as lookups use the same type.
map<char, int> reflect;

// Non-negative big integer, little-endian, base 10000 (4 decimal digits per limb).
struct BigInt{
    const static int mod = 10000;
    int a[50], len;
    BigInt(){ memset(a, 0, sizeof(a)); len = 1; }
    // Assign a non-negative int value.
    void set(int v){
        memset(a, 0, sizeof(a));
        len = 0;
        do{
            a[len++] = v % mod;
            v /= mod;
        }while(v);
    }
    BigInt operator + (const BigInt &b) const{
        BigInt res;
        res.len = max(len, b.len);
        for(int i = 0; i <= res.len; i++) res.a[i] = 0;
        for(int i = 0; i < res.len; i++){
            res.a[i] += ((i < len)? a[i] : 0) + ((i < b.len)? b.a[i] : 0);
            res.a[i + 1] += res.a[i] / mod;   // carry into the next limb
            res.a[i] %= mod;
        }
        if(res.a[res.len] > 0) res.len++;
        return res;
    }
    BigInt operator * (const BigInt &b) const{
        BigInt res;
        for(int i = 0; i < len; i++){
            int up = 0;
            for(int j = 0; j < b.len; j++){
                int temp = a[i] * b.a[j] + res.a[i + j] + up;
                res.a[i + j] = temp % mod;
                up = temp / mod;
            }
            if(up != 0) res.a[i + b.len] = up;
        }
        res.len = len + b.len;
        while(res.a[res.len - 1] == 0 && res.len > 1) res.len--;
        return res;
    }
    // Print in decimal followed by a newline; inner limbs are zero-padded to 4 digits.
    void output(){
        printf("%d", a[len - 1]);
        for(int i = len - 2; i >= 0; i--){
            printf("%04d", a[i]);
        }
        printf("\n");
    }
};

// dp[i][j]: number of virus-free strings of length i that drive the automaton to state j.
BigInt dp[M][maxn];

// Aho-Corasick automaton over the virus strings, plus the counting DP on top of it.
struct Aho{
    struct state{
        int next[SIGMA];    // trie edges; after build() also the goto transitions
        int fail, cnt;      // fail link; cnt = 1 if a virus ends here or at a state on the fail chain
    }node[maxn];
    int size;
    queue<int> q;
    void init(){
        size = 0;
        newtrie();          // state 0 is the root
        while(!q.empty()) q.pop();
    }
    int newtrie(){
        memset(node[size].next, 0, sizeof(node[size].next));
        node[size].cnt = node[size].fail = 0;
        return size++;
    }
    // Insert one virus string; every character must belong to the current alphabet.
    void insert(char *s){
        int len = strlen(s);
        int now = 0;
        for(int i = 0; i < len; i++){
            int c = reflect[s[i]];
            if(node[now].next[c] == 0){
                node[now].next[c] = newtrie();
            }
            now = node[now].next[c];
        }
        node[now].cnt = 1;
    }
    // BFS that sets fail links, propagates the virus flag along them, and fills
    // missing edges with goto transitions.
    void build(){
        node[0].fail = -1;
        q.push(0);
        while(!q.empty()){
            int u = q.front();
            q.pop();
            // Test u first: the root's fail link is -1, so evaluating
            // node[node[0].fail] would read node[-1].
            if(u && node[node[u].fail].cnt) node[u].cnt = 1;
            for(int i = 0; i < SIGMA; i++){
                if(!node[u].next[i]){
                    // Missing edge: reuse the fail state's transition (the root keeps 0, itself).
                    if(u != 0) node[u].next[i] = node[node[u].fail].next[i];
                } else{
                    int child = node[u].next[i];
                    if(u == 0) node[child].fail = 0;
                    else{
                        int v = node[u].fail;
                        while(v != -1){
                            if(node[v].next[i]){
                                node[child].fail = node[v].next[i];
                                break;
                            }
                            v = node[v].fail;
                        }
                        if(v == -1) node[child].fail = 0;
                    }
                    q.push(child);
                }
            }
        }
    }
    // Count virus-free strings of length m over the alphabet and print the count.
    void query(){
        for(int i = 0; i <= m; i++){
            for(int j = 0; j < size; j++){
                dp[i][j].set(0);
            }
        }
        dp[0][0].set(1);    // the empty string sits at the root (also makes m == 0 yield 1)
        for(int i = 0; i < m; i++){
            for(int j = 0; j < size; j++){
                for(int k = 0; k < tol; k++){
                    int v = node[j].next[k];
                    if(node[v].cnt == 0){
                        dp[i + 1][v] = dp[i + 1][v] + dp[i][j];
                    }
                }
            }
        }
        BigInt ans;
        ans.set(0);
        for(int j = 0; j < size; j++){
            if(node[j].cnt == 0){
                ans = ans + dp[m][j];
            }
        }
        ans.output();
    }
}ac;

char s[100];
int main(){
    while(~scanf("%d%d%d", &n, &m, &p)){
        scanf("%99s", s);   // the alphabet: n characters read as one token
        tol = 0;
        reflect.clear();    // drop characters left over from a previous test case
        for(int i = 0; i < n; i++){
            reflect[s[i]] = tol++;
        }
        ac.init();
        while(p--){
            scanf("%99s", s);
            ac.insert(s);
        }
        ac.build();
        ac.query();
    }
    return 0;
}