ZOJ 3494 BCD Code(AC自动机 + 数位DP)题解

题意:每位十进制数都能转化为4位二进制数,比如9是1001,127是 000100100111,现在问你,在L到R(R <= $10^{200}$)范围内,有多少数字的二进制表达式不包含模式串。

思路:显然这是一道很明显的数位DP + AC自动机的题目。但是你要是直接把数字转化为二进制,然后在Trie树上数位DP你会遇到一个问题,以为转化为二进制后,前导零变成了四位000,那么你在DP的时候还要考虑前4位是不是都是000那样就要重新跑Trie树,显然这样是很菜(不会)的。那么肯定是想办法要变成十进制跑Trie树。

那我们就预处理出一个bcd[i][j]表示在Trie树上i节点走向数字j可不可行,这样就行了。

代码:

#include<set>
#include<map>
#include<queue>
#include<cmath>
#include<string>
#include<cstdio>
#include<vector>
#include<cstring>
#include <iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 2000 + 5;
const int M = 50 + 5;
const ull seed = 131;
const int INF = 0x3f3f3f3f;
const int MOD = 1000000009;
int n, m;
int bit[205], pos;
ll dp[500][maxn];
int bcd[maxn][10];
struct Aho{
    struct state{
        int next[10];
        int fail, cnt;
    }node[maxn];
    int size;
    queue<int> q;

    void init(){
        size = 0;
        newtrie();
        while(!q.empty()) q.pop();
    }

    int newtrie(){
        memset(node[size].next, 0, sizeof(node[size].next));
        node[size].cnt = node[size].fail = 0;
        return size++;
    }

    void insert(char *s){
        int len = strlen(s);
        int now = 0;
        for(int i = 0; i < len; i++){
            int c = s[i] - '0';
            if(node[now].next[c] == 0){
                node[now].next[c] = newtrie();
            }
            now = node[now].next[c];
        }
        node[now].cnt = 1;

    }

    void build(){
        node[0].fail = -1;
        q.push(0);

        while(!q.empty()){
            int u = q.front();
            q.pop();
            if(node[node[u].fail].cnt && u) node[u].cnt |= node[node[u].fail].cnt;
            for(int i = 0; i < 10; i++){
                if(!node[u].next[i]){
                    if(u == 0)
                        node[u].next[i] = 0;
                    else
                        node[u].next[i] = node[node[u].fail].next[i];
                }
                else{
                    if(u == 0) node[node[u].next[i]].fail = 0;
                    else{
                        int v = node[u].fail;
                        while(v != -1){
                            if(node[v].next[i]){
                                node[node[u].next[i]].fail = node[v].next[i];
                                break;
                            }
                            v = node[v].fail;
                        }
                        if(v == -1) node[node[u].next[i]].fail = 0;
                    }
                    q.push(node[u].next[i]);
                }
            }
        }
    }

    ll dfs(int pos, int st, bool Max, bool lead){
        if(pos == -1) return 1;
        if(!Max && !lead && dp[pos][st] != -1) return dp[pos][st];
        int top = Max? bit[pos] : 9;
        ll ans = 0;
        for(int i = 0; i <= top; i++){
            if(lead && i == 0 && pos != 0){
                ans = (ans + dfs(pos - 1, 0, Max && i == top, lead && i == 0)) % MOD;
                continue;
            }
            if(bcd[st][i] == -1) continue;
            ans = (ans + dfs(pos - 1, bcd[st][i], Max && i == top, lead && i == 0)) % MOD;
        }
        if(!Max && !lead) dp[pos][st] = ans;
        return ans;
    }

    ll solve(char *s){
        pos = 0;
        int len = strlen(s);
        for(int i = len - 1; i >= 0; i--){
            bit[pos++] = s[i] - '0';
        }
        return dfs(pos - 1, 0, true, true);
    }

    char num[10][5] = {"0000", "0001", "0010", "0011", "0100", "0101", "0110", "0111", "1000", "1001"};
    void init_bcd(){
        memset(bcd, 0, sizeof(bcd));
        for(int i = 0; i < size; i++){
            for(int j = 0; j < 10; j++){
                int v = i;
                for(int k = 0; k < 4; k++){
                    v = node[v].next[num[j][k] - '0'];
                    if(node[v].cnt){
                        bcd[i][j] = -1;
                        break;
                    }
                }
                if(bcd[i][j] != -1) bcd[i][j] = v;
            }
        }
    }

}ac;

char s1[205], s2[205];
int main(){
    int T;
    scanf("%d", &T);
    while(T--){
        memset(dp, -1, sizeof(dp));
        scanf("%d", &n);
        ac.init();
        for(int i = 0; i < n; i++){
            scanf("%s", s1);
            ac.insert(s1);
        }
        ac.build();
        ac.init_bcd();

        scanf("%s%s", s1, s2);
        int lens1 = strlen(s1);
        int pp = lens1 - 1;
        while(s1[pp] == '0'){
            s1[pp] = '9';
            pp--;
        }
        s1[pp]--;
        if(s1[0] == '0' && lens1 > 1){
            for(int i = 1; i < lens1; i++){
                s1[i - 1] = s1[i];
            }
            s1[lens1 - 1] = '\0';
        }
//        cout << s1 << endl;
        ll ans1 = ac.solve(s1);
        ll ans2 = ac.solve(s2);
        ll ans = ((ans2 - ans1) % MOD + MOD) % MOD;
        printf("%lld\n", ans);
    }
    return 0;
}

 

posted @ 2019-07-17 11:13  KirinSB  阅读(248)  评论(0编辑  收藏  举报