ZOJ 3494 BCD Code(AC自动机 + 数位DP)题解
题意:每位十进制数都能转化为4位二进制数,比如9是1001,127是 000100100111,现在问你,在L到R(R <= $10^{200}$)范围内,有多少数字的二进制表达式不包含模式串。
思路:显然这是一道很明显的数位DP + AC自动机的题目。但是你要是直接把数字转化为二进制,然后在Trie树上数位DP你会遇到一个问题,以为转化为二进制后,前导零变成了四位000,那么你在DP的时候还要考虑前4位是不是都是000那样就要重新跑Trie树,显然这样是很菜(不会)的。那么肯定是想办法要变成十进制跑Trie树。
那我们就预处理出一个bcd[i][j]表示在Trie树上i节点走向数字j可不可行,这样就行了。
代码:
#include<set> #include<map> #include<queue> #include<cmath> #include<string> #include<cstdio> #include<vector> #include<cstring> #include <iostream> #include<algorithm> using namespace std; typedef long long ll; typedef unsigned long long ull; const int maxn = 2000 + 5; const int M = 50 + 5; const ull seed = 131; const int INF = 0x3f3f3f3f; const int MOD = 1000000009; int n, m; int bit[205], pos; ll dp[500][maxn]; int bcd[maxn][10]; struct Aho{ struct state{ int next[10]; int fail, cnt; }node[maxn]; int size; queue<int> q; void init(){ size = 0; newtrie(); while(!q.empty()) q.pop(); } int newtrie(){ memset(node[size].next, 0, sizeof(node[size].next)); node[size].cnt = node[size].fail = 0; return size++; } void insert(char *s){ int len = strlen(s); int now = 0; for(int i = 0; i < len; i++){ int c = s[i] - '0'; if(node[now].next[c] == 0){ node[now].next[c] = newtrie(); } now = node[now].next[c]; } node[now].cnt = 1; } void build(){ node[0].fail = -1; q.push(0); while(!q.empty()){ int u = q.front(); q.pop(); if(node[node[u].fail].cnt && u) node[u].cnt |= node[node[u].fail].cnt; for(int i = 0; i < 10; i++){ if(!node[u].next[i]){ if(u == 0) node[u].next[i] = 0; else node[u].next[i] = node[node[u].fail].next[i]; } else{ if(u == 0) node[node[u].next[i]].fail = 0; else{ int v = node[u].fail; while(v != -1){ if(node[v].next[i]){ node[node[u].next[i]].fail = node[v].next[i]; break; } v = node[v].fail; } if(v == -1) node[node[u].next[i]].fail = 0; } q.push(node[u].next[i]); } } } } ll dfs(int pos, int st, bool Max, bool lead){ if(pos == -1) return 1; if(!Max && !lead && dp[pos][st] != -1) return dp[pos][st]; int top = Max? bit[pos] : 9; ll ans = 0; for(int i = 0; i <= top; i++){ if(lead && i == 0 && pos != 0){ ans = (ans + dfs(pos - 1, 0, Max && i == top, lead && i == 0)) % MOD; continue; } if(bcd[st][i] == -1) continue; ans = (ans + dfs(pos - 1, bcd[st][i], Max && i == top, lead && i == 0)) % MOD; } if(!Max && !lead) dp[pos][st] = ans; return ans; } ll solve(char *s){ pos = 0; int len = strlen(s); for(int i = len - 1; i >= 0; i--){ bit[pos++] = s[i] - '0'; } return dfs(pos - 1, 0, true, true); } char num[10][5] = {"0000", "0001", "0010", "0011", "0100", "0101", "0110", "0111", "1000", "1001"}; void init_bcd(){ memset(bcd, 0, sizeof(bcd)); for(int i = 0; i < size; i++){ for(int j = 0; j < 10; j++){ int v = i; for(int k = 0; k < 4; k++){ v = node[v].next[num[j][k] - '0']; if(node[v].cnt){ bcd[i][j] = -1; break; } } if(bcd[i][j] != -1) bcd[i][j] = v; } } } }ac; char s1[205], s2[205]; int main(){ int T; scanf("%d", &T); while(T--){ memset(dp, -1, sizeof(dp)); scanf("%d", &n); ac.init(); for(int i = 0; i < n; i++){ scanf("%s", s1); ac.insert(s1); } ac.build(); ac.init_bcd(); scanf("%s%s", s1, s2); int lens1 = strlen(s1); int pp = lens1 - 1; while(s1[pp] == '0'){ s1[pp] = '9'; pp--; } s1[pp]--; if(s1[0] == '0' && lens1 > 1){ for(int i = 1; i < lens1; i++){ s1[i - 1] = s1[i]; } s1[lens1 - 1] = '\0'; } // cout << s1 << endl; ll ans1 = ac.solve(s1); ll ans2 = ac.solve(s2); ll ans = ((ans2 - ans1) % MOD + MOD) % MOD; printf("%lld\n", ans); } return 0; }