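// ZOJ (zju) 3545
//
// Aho-Corasick automaton + bitmask (state-compression) DP
//
// Note: the same pattern string may appear more than once in the input; when
// it is matched, the weights of all of its occurrences are added together.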

#include <cstdio>
#include <queue>
#include <cstring>
using namespace std;

// Debug hook: D(x) expands to nothing, so every D(...) statement in this file
// is compiled out.
#define D(x) 

const int MAX_N = 15;                        // capacity of the per-pattern weight array w[]
const int MAX_LEN = 105;                     // size of the pattern input buffer st[]
const int MAX_CHILD_NUM = 4;                 // outgoing edges per trie node (one per alphabet symbol)
const int MAX_NODE_NUM = MAX_LEN * MAX_N;    // capacity of every per-node array in Trie

//1.init() 2.insert() 3.build() 4.query()
// Aho-Corasick automaton over a 4-symbol alphabet: 'A', 'T', 'C', and a fourth
// symbol that every other character maps to.
// Call order: 1. init()  2. insert() for each pattern  3. build()  4. query()
//
// After build(), next[][] is a complete transition table (no -1 entries) and
// count[v] is the bitmask of every pattern id that ends at node v or at a node
// reachable from v through fail links.
struct Trie
{
    int next[MAX_NODE_NUM][MAX_CHILD_NUM]; // child table; -1 marks a missing edge until build()
    int fail[MAX_NODE_NUM];                // failure links, assigned by build()
    int count[MAX_NODE_NUM];               // bitmask of pattern ids recognised at this node
    int node_cnt;                          // number of nodes handed out by newnode()
    bool vis[MAX_NODE_NUM];                // scratch flags; cleared by init(), unused by the other methods
    int root;

    // Resets the automaton to a single root node holding no patterns.
    void init()
    {
        node_cnt = 0;
        root = newnode();
        memset(vis, 0, sizeof(vis));
    }

    // Allocates the next node with no children and an empty pattern mask and
    // returns its index. There is no capacity check: the caller must keep the
    // total number of nodes within MAX_NODE_NUM.
    int newnode()
    {
        const int idx = node_cnt++;
        for (int i = 0; i < MAX_CHILD_NUM; i++)
            next[idx][i] = -1;
        count[idx] = 0;
        return idx;
    }

    // Maps a character to its edge index: 'A' -> 0, 'T' -> 1, 'C' -> 2,
    // anything else -> 3.
    int get_id(char a)
    {
        switch (a)
        {
        case 'A':
            return 0;
        case 'T':
            return 1;
        case 'C':
            return 2;
        default:
            return 3;
        }
    }

    // Inserts the NUL-terminated pattern buf and sets bit `id` on its final
    // node. Several patterns may end on the same node; each keeps its own bit.
    // `id` must be a valid bit index for an int.
    void insert(char buf[], int id)
    {
        int now = root;
        for (int i = 0; buf[i]; i++)
        {
            // Named `child` so that it does not shadow the pattern-id parameter.
            const int child = get_id(buf[i]);
            if (next[now][child] == -1)
                next[now][child] = newnode();
            now = next[now][child];
        }
        count[now] |= (1 << id);
    }

    // Computes fail links breadth-first, replaces every missing edge with the
    // corresponding transition of the fail node, and merges each node's
    // pattern mask with the mask of its fail node.
    void build()
    {
        std::queue<int> Q;
        fail[root] = root;
        for (int i = 0; i < MAX_CHILD_NUM; i++)
        {
            if (next[root][i] == -1)
            {
                // A missing edge at the root loops back to the root.
                next[root][i] = root;
            }
            else
            {
                fail[next[root][i]] = root;
                Q.push(next[root][i]);
            }
        }
        while (!Q.empty())
        {
            const int now = Q.front();
            Q.pop();
            for (int i = 0; i < MAX_CHILD_NUM; i++)
            {
                if (next[now][i] == -1)
                {
                    // fail[now] was dequeued earlier, so its row is already complete.
                    next[now][i] = next[fail[now]][i];
                }
                else
                {
                    const int child = next[now][i];
                    fail[child] = next[fail[now]][i];
                    // Inherit every pattern recognised at the fail node.
                    count[child] |= count[fail[child]];
                    Q.push(child);
                }
            }
        }
    }

    // Scans the NUL-terminated text buf and returns the bitmask of every
    // pattern id recognised while reading it. Must be called after build().
    int query(char buf[])
    {
        int now = root;
        int res = 0;

        for (int i = 0; buf[i]; i++)
        {
            now = next[now][get_id(buf[i])];
            // count[] holds bitmasks already merged along the fail links by
            // build(), so the masks are OR-ed; adding them would count bits
            // more than once and produce a meaningless value.
            res |= count[now];
        }
        return res;
    }

    // Prints one line per allocated node: index, fail link, pattern mask and
    // the four transitions.
    void debug()
    {
        for (int i = 0; i < node_cnt; i++)
        {
            printf("id = %3d,fail = %3d,end = %3d,chi = [", i, fail[i], count[i]);
            for (int j = 0; j < MAX_CHILD_NUM; j++)
                printf("%2d", next[i][j]);
            printf("]\n");
        }
    }
}ac;

// Number of status columns in dp: 1 << 10 = 1024 bitmasks plus 20 spare slots,
// i.e. enough for every subset of at most 10 patterns.
// NOTE(review): MAX_N is 15 but this only covers n <= 10 — confirm the input
// never supplies more than 10 patterns.
const int MAX_STATUS = (1 << 10) + 20;

int n, len;          // values read by main() at the start of each test case
char st[MAX_LEN];    // input buffer for one pattern string
int w[MAX_N];        // w[i] is the weight read together with pattern i
// dp[layer][node][status]: reachability flag, indexed by a two-layer rolling
// step index, an automaton node, and a bitmask of matched patterns.
bool dp[2][MAX_NODE_NUM][MAX_STATUS];

// Returns the sum of w[i] over every index i < n whose bit is set in status.
int cal(int status)
{
    int total = 0;
    // Walk the mask from the lowest bit upwards, stopping early once no set
    // bits remain.
    for (int i = 0; i < n && status != 0; ++i, status >>= 1)
    {
        if (status & 1)
            total += w[i];
    }
    return total;
}

// Runs the DP over the automaton for `len` steps and returns the largest
// cal(status) among all (node, status) pairs reachable after exactly `len`
// transitions, or -1 when every reachable total is below zero.
int work()
{
    const int status_cnt = 1 << n;
    const int node_total = ac.node_cnt;

    memset(dp, 0, sizeof(dp));
    dp[0][ac.root][0] = true;

    for (int step = 0; step < len; ++step)
    {
        const int cur = step & 1;
        const int nxt = cur ^ 1;

        // Wipe the layer about to be written (only the part that is in use).
        for (int node = 0; node < node_total; ++node)
            memset(dp[nxt][node], 0, status_cnt * sizeof(dp[nxt][node][0]));

        for (int node = 0; node < node_total; ++node)
        {
            for (int mask = 0; mask < status_cnt; ++mask)
            {
                if (dp[cur][node][mask])
                {
                    // Append each of the four symbols and record where it leads.
                    for (int c = 0; c < 4; ++c)
                    {
                        const int to = ac.next[node][c];
                        dp[nxt][to][mask | ac.count[to]] = true;
                    }
                }
            }
        }
    }

    int best = -1;
    const int last = len & 1;
    for (int node = 0; node < node_total; ++node)
    {
        for (int mask = 0; mask < status_cnt; ++mask)
        {
            if (!dp[last][node][mask])
                continue;
            const int value = cal(mask);
            if (value > best)
                best = value;
        }
    }
    return best;
}

// Reads test cases until the input ends. Each case gives n and len followed by
// n lines of "pattern weight"; prints the result of work(), or the fixed
// message when that result is negative.
int main()
{
    // Require both integers: comparing against EOF alone would spin forever
    // on a non-numeric token, because scanf would keep returning 0 without
    // consuming it.
    while (scanf("%d%d", &n, &len) == 2)
    {
        ac.init();
        for (int i = 0; i < n; i++)
        {
            // Width 104 keeps the pattern plus its terminator inside
            // st[MAX_LEN] (MAX_LEN == 105).
            if (scanf("%104s%d", st, &w[i]) != 2)
                return 1; // truncated or malformed test case
            ac.insert(st, i);
        }
        ac.build();
        const int ans = work();
        if (ans < 0)
            puts("No Rabbit after 2012!");
        else
            printf("%d\n", ans);
    }
    return 0;
}
// View Code
//
// posted @ 2015-03-27 21:19 by 金海峰 | views (112) | comments (0) | edit | bookmark | report