poj3341

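AC自动机上做dp。如果直接把四种字符的剩余个数作为四维状态,需要40^4 * 50 * 10的空间,无法承受;因此要把这四个计数hash压缩成一个整数(四个计数之和不超过40,状态数最多11^4 = 14641),实际代码中dp数组的大小是15000 * 505。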

最大的难点在于hash。

hash一个数列f,数列中的每一位都有一个上限g,即f[i]<=g[i]。

那么可以将该数列hash为这样一个整数,这个整数的每一个位的进制都不同,第i位的进制是g[i] + 1,即第i位满g[i]+1则可进位。(当然由于g[i]是该位的上限,所以永远不可能进位)

用p[i]表示(g[0]+1)*(g[1]+1)*...*(g[i - 1]+1),其中p[0] = 1。那么最终f被hash的结果是p[0]*f[0]+p[1]*f[1]+...。(代码中的num数组是从最后一位向前累乘的,即num[3] = 1,num[i] = (g[i + 1]+1)*num[i + 1],原理相同。)

#include <cstdio>
#include <algorithm>
#include <queue>
#include <map>
#include <cstring>
using namespace std;

// Debug-trace hook: the macro expands to nothing, so any D(printf(...)) call
// in this file compiles away. Give it the body `x` to enable the traces.
#define D(x) 

const int MAX_LEN = 50;                // capacity of the shared input buffer `st`
const int MAX_N = 55;                  // presumably an upper bound on the number of patterns -- confirm against the problem limits
const int MAX_NODE_NUM = 12 * MAX_N;   // automaton node capacity
const int MAX_CHILD_NUM = 4;           // alphabet size (see Trie::get_id)

// Aho-Corasick automaton over a four-letter alphabet.
//
// Call order: 1.init() 2.insert() for every pattern 3.build() 4.query()
//
// After build():
//   * next[][] is a complete transition table (missing edges are redirected
//     through the fail links), so every next[u][c] is a valid node id;
//   * count[u] is the number of inserted patterns ending at u plus the
//     count of its fail target, i.e. it is accumulated along the fail chain;
//   * terminal[u] is still the number of patterns that end exactly at u.
//
// No capacity check is performed: inserting more than MAX_NODE_NUM nodes
// writes out of bounds.
struct Trie
{
    int next[MAX_NODE_NUM][MAX_CHILD_NUM];
    int fail[MAX_NODE_NUM];
    int count[MAX_NODE_NUM];     // accumulated over the fail chain by build()
    int terminal[MAX_NODE_NUM];  // patterns ending exactly at this node
    int node_cnt;
    bool vis[MAX_NODE_NUM];      // scratch marks used by query()
    int root;

    // Resets the automaton to a single root node.
    void init()
    {
        node_cnt = 0;
        root = newnode();
        std::memset(vis, 0, sizeof(vis));
    }

    // Allocates a fresh node with no children and zero counters;
    // returns its id.
    int newnode()
    {
        for (int i = 0; i < MAX_CHILD_NUM; i++)
            next[node_cnt][i] = -1;
        count[node_cnt] = 0;
        terminal[node_cnt] = 0;
        return node_cnt++;
    }

    // Maps a letter to its child index: 'A'->0, 'C'->1, 'T'->2,
    // anything else ->3.
    int get_id(char a)
    {
        if (a == 'A')
            return 0;
        if (a == 'C')
            return 1;
        if (a == 'T')
            return 2;
        return 3;
    }

    // Adds the NUL-terminated pattern `buf`. Must be called before build().
    void insert(char buf[])
    {
        int now = root;
        for (int i = 0; buf[i]; i++)
        {
            int id = get_id(buf[i]);
            if (next[now][id] == -1)
                next[now][id] = newnode();
            now = next[now][id];
        }
        count[now]++;
        terminal[now]++;
    }

    // Computes fail links breadth-first, completes the transition table and
    // accumulates count[] along the fail links.
    void build()
    {
        std::queue<int> Q;
        fail[root] = root;
        for (int i = 0; i < MAX_CHILD_NUM; i++)
            if (next[root][i] == -1)
                next[root][i] = root;
            else
            {
                fail[next[root][i]] = root;
                Q.push(next[root][i]);
            }
        while (!Q.empty())
        {
            int now = Q.front();
            Q.pop();
            for (int i = 0; i < MAX_CHILD_NUM; i++)
                if (next[now][i] == -1)
                    next[now][i] = next[fail[now]][i];
                else
                {
                    int child = next[now][i];
                    fail[child] = next[fail[now]][i];
                    // The fail target is shallower, so its count is already
                    // final when it is folded into the child.
                    count[child] += count[fail[child]];
                    Q.push(child);
                }
        }
    }

    // Returns how many inserted patterns (duplicates counted separately)
    // occur at least once in the NUL-terminated text `buf`.
    // Requires build() to have been called.
    int query(char buf[])
    {
        int now = root;
        int res = 0;

        std::memset(vis, 0, sizeof(vis));
        for (int i = 0; buf[i]; i++)
        {
            now = next[now][get_id(buf[i])];
            // Walk the fail chain; a visited node means the rest of the
            // chain was already counted, which also keeps each pattern
            // from being matched twice.
            for (int temp = now; temp != root && !vis[temp]; temp = fail[temp])
            {
                // Use the per-node terminal count: count[] already contains
                // the whole fail chain after build(), so summing it along
                // the chain would count the same pattern repeatedly.
                res += terminal[temp];
                vis[temp] = true;
            }
        }
        return res;
    }

    // Dumps every node (id, fail link, accumulated count, children).
    void debug()
    {
        for (int i = 0; i < node_cnt; i++)
        {
            std::printf("id = %3d,fail = %3d,end = %3d,chi = [",i,fail[i],count[i]);
            for (int j = 0; j < MAX_CHILD_NUM; j++)
                std::printf("%2d",next[i][j]);
            std::printf("]\n");
        }
    }
}ac;

// Shared state for the solver.
char st[MAX_LEN];   // input buffer: holds each pattern in turn, then the text to rearrange
int n;              // number of patterns in the current test case (0 ends the input)
int num[4];         // mixed-radix place values used by myhash(); filled in by work()
int num2[4];        // not referenced anywhere in the visible code
int dp[15000][505]; // dp[hash of remaining letter counts][automaton node]; see work()

// Packs the four letter counts f[0..3] into a single integer by weighting
// each one with its place value from the global num[] table.
int myhash(int f[])
{
    int code = 0;
    int digit = 4;
    while (digit-- > 0)
        code += num[digit] * f[digit];
    return code;
}

// Solves one test case: using exactly the letters of the text in `st`
// (in any order), returns the largest total of ac.count[] values that can
// be collected while feeding those letters through the automaton `ac`
// starting at its root.
//
// dp[h][u] = ac.count[u] + best value obtainable from node u when the
// multiset of still-unused letters is the one encoded by h.
//
// Requires ac.build() to have been called and the text to be in `st`.
// No bounds check is made against the dimensions of `dp`.
int work()
{
    // How many of each letter the text provides.
    int temp[4] = {0};
    for (int i = 0; st[i]; i++)
    {
        temp[ac.get_id(st[i])]++;
    }

    // Mixed-radix place values: digit i ranges over 0..temp[i], so the
    // weight of digit i is the product of (temp[k] + 1) for all k > i.
    num[3] = 1;
    for (int i = 2; i >= 0; i--)
    {
        num[i] = (temp[i + 1] + 1) * num[i + 1];
    }

    // No table reset is needed: removing one letter always yields a count
    // vector that is lexicographically smaller, so every cell read below
    // was written earlier in this same call.
    int f[4];
    for (f[0] = 0; f[0] <= temp[0]; f[0]++)
        for (f[1] = 0; f[1] <= temp[1]; f[1]++)
            for (f[2] = 0; f[2] <= temp[2]; f[2]++)
                for (f[3] = 0; f[3] <= temp[3]; f[3]++)
                {
                    const int h = myhash(f);
                    for (int u = 0; u < ac.node_cnt; u++)
                    {
                        int best = 0;
                        for (int j = 0; j < 4; j++)
                        {
                            if (f[j] > 0)
                            {
                                // Spending one letter j lowers the hash by
                                // exactly its place value.
                                best = max(best, dp[h - num[j]][ac.next[u][j]]);
                            }
                        }
                        dp[h][u] = best + ac.count[u];
                    }
                }
    return dp[myhash(temp)][ac.root];
}

int main()
{
    int t = 0;
    while (scanf("%d", &n), n)
    {
        ac.init();
        for (int i = 0; i < n; i++)
        {
            scanf("%s", st);
            ac.insert(st);
        }
        ac.build();
        scanf("%s", st);
        printf("Case %d: %d\n", ++t, work());
    }
    return 0;
}
View Code

 

posted @ 2015-03-30 21:19  金海峰  阅读(151)  评论(0)  编辑  收藏  举报