bzoj3530

AC自动机+数位dp

先建立AC自动机trie图,然后在上面跑dp。

dp[i][j][0]表示到了天际线,也就是j==n时<n的方案数

dp[i][j][1]表示到了天际线并且==n的方案数

f[i][j]表示j<n时合法的方案数

因为如果位数比n小,那么我们不用在乎每位选什么,只要没有前导0就行了,因为位数小肯定比原数小,那么我们就dp一下,计算不经过危险节点的方案数

但是如果到了天际线,也就是j==n,那么这时候要小心统计,具体转移看代码,就是分这一位比原数对应位的大小,然后所有的dp[i][j][0/1],f[i][j]加起来就是答案

对于前导零的处理,就是dp[i][0][0/1]=f[i][1]=0,i代表第一个0节点,这样就不会统计前导0了 dp数组只有在天际线的时候才统计

1A~

#include<bits/stdc++.h>
using namespace std;
const int N = 1510, mod = 1000000007;
int m, n;
char s[N], num[N];
struct ac_automation {
    int cnt, root, pos;
    int ch[N][11], fail[N], danger[N];
    long long dp[N][N][2], f[N][N];
    void insert(char s[])
    {
        int now = root, len = strlen(s);
        for(int i = 0; i < len; ++i) 
        {
            int &v = ch[now][s[i] - '0'];
            if(v == 0) v = ++cnt;
            now = v;
        }
        danger[now] = 1;
    }
    void construct_fail()
    {
        queue<int> q;
        for(int i = 0; i < 10; ++i) if(ch[root][i] == 0) 
        {
            ++cnt;
            if(i == 0) pos = cnt;
            ch[root][i] = cnt;
        }
        for(int i = 0; i < 10; ++i) q.push(ch[root][i]);
        while(!q.empty())
        {
            int u = q.front();
            q.pop();
            for(int i = 0; i < 10; ++i)
            {
                int &v = ch[u][i];
                if(v == 0) v = ch[fail[u]][i];
                else
                {
                    fail[v] = ch[fail[u]][i];
                    danger[v] |= danger[fail[v]];
                    q.push(v);
                }
            }
        }
    }
    void solve()
    {
        for(int i = 1; i < 10; ++i) 
        {
            int x = ch[root][i];
            if(danger[x]) continue;
            if(i < num[1] - '0') dp[x][1][0] = 1;
            else if(i == num[1] - '0') dp[x][1][1] = 1;
            f[x][1] = 1;
        }
        long long ans = 0;
        for(int j = 1; j < n; ++j)
            for(int i = 1; i <= cnt; ++i) if(!danger[i])
                for(int k = 0; k < 10; ++k)
                {
                    int x = ch[i][k];
                    if(danger[x]) continue;
                    if(k < num[j + 1] - '0') dp[x][j + 1][0] = (dp[x][j + 1][0] + dp[i][j][0] + dp[i][j][1]) % mod;
                    else if(k == num[j + 1] - '0')
                    {
                        dp[x][j + 1][0] = (dp[x][j + 1][0] + dp[i][j][0]) % mod;
                        dp[x][j + 1][1] = (dp[x][j + 1][1] + dp[i][j][1]) % mod;
                    }
                    else dp[x][j + 1][0] = (dp[x][j + 1][0] + dp[i][j][0]) % mod;
                    if(j + 1 < n) f[x][j + 1] = (f[x][j + 1] + f[i][j]) % mod;
                }
        for(int i = 1; i <= cnt; ++i) 
            for(int j = 1; j <= n; ++j)
            {
                if((i == pos && j == 1) || danger[i]) continue;
                ans = (ans + f[i][j]) % mod;
                if(j == n) ans = (ans + dp[i][j][0] + dp[i][j][1]) % mod;
            }
        printf("%lld\n", ans);        
    }
} ac;
int main()
{
    scanf("%s%d", num + 1, &m);
    n = strlen(num + 1);
    for(int i = 1; i <= m; ++i) scanf("%s", s), ac.insert(s);
    ac.construct_fail();
    ac.solve();
    return 0;
}
View Code

 

posted @ 2017-08-27 18:36  19992147  阅读(128)  评论(0编辑  收藏  举报