POJ 1625 Censored ( Trie图 && DP && 高精度 )

题意 : 给出 n 个单词组成的字符集 以及 p 个非法串,问你用字符集里面的单词构造长度为 m 的单词的方案数有多少种?

分析 :先构造出 Trie 图方便进行状态转移,这与在 POJ 2278 中的步骤是一样的,只不过最后的DP状态转移方式 2778 是利用了矩阵进行转移的,那是因为需要构造的串的长度非常长!只能利用矩阵转移。但是这道题需要构造的串的长度最多也就只有 50 ,可以利用普通的DP方法进行转移。我们定义 DP[i][j] 为以长度为 i 以字符 j 为结尾的串的种类数是多少,那么状态转移方程很显然就是 DP[i+1][k] += DP[i][j] * G[j][k] 这个方程表示现在 k 到 j 有一条边并且从k 走一步可以到 j 的方案数是 G[j][k] ( Trie 图构建出来的 ),那么现在 DP[i+1][k] 就很明显可以从 DP[i][j] 转移而来,DP的初始状态为 DP[0][0] = 0 && DP[0][i] = 0。

注意 : 

① 因为没有要求对答案进行求模运算,答案可能很大,因为如果 p = 0,而n 和 m 都达到最大的50,那么答案就是 50^50,所以需要用到高精度。

② 字符可能有超过 128 的,也就是有负数情况,用map转化

#include<string.h>
#include<stdio.h>
#include<iostream>
#include<queue>
#include<map>
using namespace std;
const int Max_Tot = 111;
const int Letter = 256;
int G[111][111], n;
map<int, int> mp;
struct bign{
    #define MAX_B (100)
    #define MOD (10000)
    int a[MAX_B], n;
    bign() { a[0] = 0, n = 1; }
    bign(int num)
    {
        n = 0;
        do {
            a[n++] = num % MOD;
            num /= MOD;
        } while(num);
    }
    bign& operator= (int num)
    { return *this = bign(num); }
    bign operator+ (const bign& b) const
    {
        bign c = bign();
        int cn = max(n, b.n), d = 0;
        for(int i = 0, x, y; i < cn; i++)
        {
            x = (n > i) ? a[i] : 0;
            y = (b.n > i) ? b.a[i] : 0;
            c.a[i] = (x + y + d) % MOD;
            d = (x + y + d) / MOD;
        }
        if(d) c.a[cn++] = d;
        c.n = cn;
        return c;
    }
    bign& operator+= (const bign& b)
    {
        *this = *this + b;
        return *this;
    }
    bign operator* (const bign& b) const
    {
        bign c = bign();
        int cn = n + b.n, d = 0;
        for(int i = 0; i <= cn; i++)
            c.a[i] = 0;
        for(int i = 0; i < n; i++)
        for(int j = 0; j < b.n; j++)
        {
            c.a[i + j] += a[i] * b.a[j];
            c.a[i + j + 1] += c.a[i + j] / MOD;
            c.a[i + j] %= MOD;
        }
        while(cn > 0 && !c.a[cn-1]) cn--;
        if(!cn) cn++;
        c.n = cn;
        return c;
    }
    friend ostream& operator<< (ostream& _cout, const bign& num)
    {
        printf("%d", num.a[num.n - 1]);
        for(int i = num.n - 2; i >= 0; i--)
            printf("%04d", num.a[i]);
        return _cout;
    }
};
struct Aho{
    struct StateTable{
        int Next[Letter];
        int fail, flag;
    }Node[Max_Tot];
    int Size;
    queue<int> que;

    inline void init(){
        while(!que.empty()) que.pop();
        memset(Node[0].Next, 0, sizeof(Node[0].Next));
        Node[0].fail = Node[0].flag = 0;
        Size = 1;
    }

    inline void insert(char *s){
        int now = 0;
        for(int i=0; s[i]; i++){
            int idx = mp[s[i]];
            if(!Node[now].Next[idx]){
                memset(Node[Size].Next, 0, sizeof(Node[Size].Next));
                Node[Size].fail = Node[Size].flag = 0;
                Node[now].Next[idx] = Size++;
            }
            now = Node[now].Next[idx];
        }
        Node[now].flag = 1;
    }

    inline void BuildFail(){
        Node[0].fail = 0;
        for(int i=0; i<n; i++){
            if(Node[0].Next[i]){
                Node[Node[0].Next[i]].fail = 0;
                que.push(Node[0].Next[i]);
            }else Node[0].Next[i] = 0;///必定指向根节点
        }
        while(!que.empty()){
            int top = que.front(); que.pop();
            if(Node[Node[top].fail].flag) Node[top].flag = 1;
            for(int i=0; i<n; i++){
                int &v = Node[top].Next[i];
                if(v){
                    que.push(v);
                    Node[v].fail = Node[Node[top].fail].Next[i];
                }else v = Node[Node[top].fail].Next[i];
            }
        }
    }

    inline void BuildMap(){
        for(int i=0; i<Size; i++)
            for(int j=0; j<Size; j++)
                G[i][j] = 0;

        for(int i=0; i<Size; i++){
            for(int j=0; j<n; j++){
                if(!Node[ Node[i].Next[j] ].flag)
                    G[i][Node[i].Next[j]]++;
            }
        }
    }
}ac;

#define MAX_M (55)
bign dp[MAX_M][Max_Tot];

char s[51];
int main(void)
{
    int m, p;
    while(~scanf("%d %d %d\n", &n, &m, &p)){
        mp.clear();
        gets(s);
        int len = strlen(s);
        for(int i=0; i<len; i++)
            mp[s[i]] = i;

        ac.init();
        for(int i=0; i<p; i++){
            gets(s);
            ac.insert(s);
        }
        ac.BuildFail();
        ac.BuildMap();

        for(int i=0; i<=m; i++)
            for(int j=0; j<ac.Size; j++)
                dp[i][j] = bign();

        dp[0][0] = 1;
        for(int i=0; i<m; i++)
        for(int j=0; j<ac.Size; j++){
            for(int k=0; k<ac.Size; k++){
                dp[i+1][k] += dp[i][j] * G[j][k];
            }
        }

        bign ans = bign();

        for(int i=0; i<ac.Size; i++)
            ans += dp[m][i];

        cout<<ans<<endl;
    }
    return 0;
}
View Code

 

posted @ 2019-10-01 16:13  shuai_hui  阅读(152)  评论(0编辑  收藏  举报