hdu4758 hdu2825 hdu4057 AC自动机与状态压缩dp的结合

最近做到好几道关于AC自动机与状态压缩dp的结合的题,这里总结一下。

题目一般会给出m个字符串,m不超过10,然后求长度为len并且包含特定给出的字符串集合的字符串个数。

以HDU 4758为例:

把题意抽象为:给出两个字符串,且只包含两种字符 'R'、'D',现在求满足下列条件的字符串个数:

1、字符串必须包含上述两个字符串。

2、字符串长度为(m+n),其中包含n个'D',m个'R'。

如果不用AC自动机来做,这道题还真没法做了,因为不管怎样都找不到正确的dp状态转移方程。

而如果引入AC自动机,把在AC自动机上的结点当做dp的一个维度的状态,那么问题就可解了。

dp[c][zt][i][j]:c表示当前状态的字符串对应于AC自动机上的结点,zt表示给定字符串取舍情况的压缩状态,i表示'D'的个数,j表示'R'的个数。

那么dp[c][zt][i][j]表示当前状态字符串的个数。

循环到dp[c][zt][i][j]时,其实dp[c][zt][i][j]已经被计算出来了,然后遍历trie树中c的所有子节点,计算它们的dp值。

最外层循环应该是字符串长度的循环,循环次数是题目要求的字符串长度,第二层循环是trie树中的所有节点,第三层是字符串取舍状态,最后是遍历c节点的所有子节点(说是子节点,其实是对c节点的下一个字符进行遍历,需要使用fail指针)。

c节点并不代表某个具体的字符串,它是指所有能到达c节点的字符串,dp的值就是保存这些字符串中满足条件的字符串个数。

AC自动机的作用就是增加一个状态维度,使dp过程有足够的信息来转移状态。

#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
const int mod = 1000000007;
int ch[202][2],End[202],cur,fail[202],last[202];
void get_fail() {
    int now,tmpFail,Next;
    queue<int> q;
    for(int j=0;j<2;j++) {
        if(ch[0][j]) {
            q.push(ch[0][j]);
            fail[ch[0][j]] = 0;
            last[ch[0][j]] = 0;
        }
    }
    while(!q.empty()) {
        now = q.front();q.pop();
        for(int j=0;j<2;j++) {
            if(!ch[now][j]) continue;
            Next = ch[now][j];
            q.push(Next);
            tmpFail = fail[now];
            while(tmpFail&&!ch[tmpFail][j]) tmpFail = fail[tmpFail];
            fail[Next] = ch[tmpFail][j];
            last[Next] = End[fail[Next]] ? fail[Next]:last[fail[Next]];
        }
    }
}
int dp[202][4][102][102];//dp[c][zt][i][j]
int main() {
    int T,m,n;
    char str0[3][104];
    scanf("%d",&T);
    while(T--) {
        cur=1;
        scanf("%d%d",&m,&n);
        n++;m++;
        memset(End,0,sizeof(End));
        memset(ch,0,sizeof(ch));
        memset(last,0,sizeof(last));
        for(int i=1;i<=2;i++) {
            scanf("%s",str0[i]);
            int len = strlen(str0[i]);
            int now = 0;
            for(int j=0;j<len;j++) {
                if(str0[i][j]=='R') str0[i][j]=1;
                else str0[i][j]=0;
                if(ch[now][str0[i][j]]==0) ch[now][str0[i][j]] = cur++;
                now = ch[now][str0[i][j]];
            }
            End[now] = i;
        }
        get_fail();


        memset(dp,0,sizeof(dp));
        dp[0][0][0][0]=1;
        for(int i=0;i<n;i++) //要特别注意这里内外循环顺序,必须把i、j循环放在外面
        for(int j=0;j<m;j++) {
            for(int c=0;c<cur;c++) {
                for(int zt=0;zt<=3;zt++){
                    if(dp[c][zt][i][j])
                    for(int k=0;k<2;k++) {
                        if(k==0&&i==n-1) continue;
                        else if(k==1&&j==m-1) continue;
                        int now=c;
                        while(now&&!ch[now][k]) now = fail[now];
                        now = ch[now][k];

                        int t=0;
                        if(End[now])
                            t = t|(1<<(End[now]-1));
                        int tmp = now;
                        while(last[tmp]) {
                            t = t|(1<<(End[last[tmp]]-1));
                            tmp = last[tmp];
                        }
                        if(k==0) {
                            dp[now][zt|t][i+1][j] += dp[c][zt][i][j];
                            if(dp[now][zt|t][i+1][j]>=mod) dp[now][zt|t][i+1][j]-=mod;
                        }
                        else if(k==1) {
                            dp[now][zt|t][i][j+1] += dp[c][zt][i][j];
                            if(dp[now][zt|t][i][j+1]>=mod) dp[now][zt|t][i][j+1]-=mod;
                        }
                    }
                }
            }
        }
        long long ans=0;
        for(int i=0;i<cur;i++) {
            ans+=dp[i][3][n-1][m-1];
            if(ans>=mod) ans-=mod;
        }
        printf("%I64d\n",ans);
    }
}

注意循环的内外顺序,一般情况下,字符串长度的循环都是放在外层,也就是说,一定要先计算出长度为i的所有字符串状态,才能计算长度为i+1的所有字符串状态。

类似的 HDU 2825 :给 m 个单词构成的集合,求至少包含 k 个单词且长度为n的字符串个数。

#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<queue>
using namespace std;
const int mod=20090717;
int ch[11*11][26],End[11*11],cur,fail[11*11],last[11*11];
char str0[12][12];
void get_fail() {
    int now,tmpFail,Next;
    queue<int> q;
    for(int j=0;j<26;j++) {
        if(ch[0][j]) {
            q.push(ch[0][j]);
            fail[ch[0][j]] = 0;
            last[ch[0][j]] = 0;
        }
    }
    while(!q.empty()) {
        now = q.front();q.pop();
        for(int j=0;j<26;j++) {
            if(!ch[now][j]) continue;
            Next = ch[now][j];
            q.push(Next);
            tmpFail = fail[now];
            while(tmpFail&&!ch[tmpFail][j]) tmpFail = fail[tmpFail];
            fail[Next] = ch[tmpFail][j];
            last[Next] = End[fail[Next]] ? fail[Next]:last[fail[Next]];
        }
    }
}
int dp[27][11*11][1055];
int main()
{
    int sum[1055];
    for(int I=0;I<(1<<10);I++) {
            sum[I]=0;
            int tmp=I;
            while(tmp) {
                if(tmp&1) sum[I]++;
                tmp>>=1;
            }
    }
    int n,m,k;
    while(scanf("%d%d%d",&n,&m,&k)!=EOF&&(n||m||k))
    {
        cur=1;
        int len[13];
        memset(End,0,sizeof(End));
        memset(ch,0,sizeof(ch));
        memset(last,0,sizeof(last));
        for(int i=1;i<=m;i++) {
            scanf("%s",str0[i]);
            len[i] = strlen(str0[i]);
            int now = 0;
            for(int j=0;j<len[i];j++) {
                str0[i][j]-='a';
                if(ch[now][str0[i][j]]==0) ch[now][str0[i][j]] = cur++;
                now = ch[now][str0[i][j]];
                str0[i][j]+='a';
            }
            End[now] = i;

        }
        get_fail();
        memset(dp,0,sizeof(dp));
        dp[0][0][0]=1;
        int pre=0,zt=0;
        int ans=0;
        for(int i=0;i<n;i++) {
            for(int j=0;j<cur;j++) {
                for(int zt=0;zt<(1<<m);zt++) {
                    if(dp[i][j][zt]) {
                    for(int c=0;c<26;c++) {
                        int now = j;
                        while(now&&!ch[now][c]) now = fail[now];
                        now = ch[now][c];
                        int t=0;
                        if(End[now])
                            t = t|(1<<(End[now]-1));
                        int tmp = now;
                        while(last[tmp]) {
                            t = t|(1<<(End[last[tmp]]-1));
                            tmp = last[tmp];
                        }
                        dp[i+1][now][zt|t] += dp[i][j][zt];
                        if(dp[i+1][now][zt|t]>=mod) dp[i+1][now][zt|t]-=mod;
                    }
                    }
                }
            }
        }
        for(int I=0;I<(1<<m);I++) {
            if(sum[I]>=k) {
                for(int j=0;j<cur;j++){
                    ans+=dp[n][j][I];
                    if(ans>=mod) ans-=mod;

                }
            }
        }
        printf("%d\n",ans);
    }
}

 

HDU 4057:给出一些模式串,每个串有一定的价值,现在构造一个长度为M的串,问最大的价值为多少,每个模式串最多统计一次。

#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
int ch[11*102][4],End[11*102],cur,fail[11*102],last[11*102];
int w[11];
char str[102],str0[11][102];
void get_fail()
{
    int now,tmpFail,Next;
    queue<int> q;
    //用bfs生成fail
    //初始化队列
    for(int j=0; j<4; j++)
    {
        if(ch[0][j])
        {
            q.push(ch[0][j]);
            fail[ch[0][j]] = 0;
            last[ch[0][j]] = 0;
        }
    }
    while(!q.empty())
    {
        //从队列中拿出now
        //此时now中的fail、last已经算好了
        //下面计算的是ch[now][j]中的fail、last。
        now = q.front();
        q.pop();
        for(int j=0; j<4; j++)
        {
            if(!ch[now][j]) continue;
            Next = ch[now][j];
            q.push(Next);
            tmpFail = fail[now];
            while(tmpFail&&!ch[tmpFail][j]) tmpFail = fail[tmpFail];
            fail[Next] = ch[tmpFail][j];
            last[Next] = End[fail[Next]] ? fail[Next]:last[fail[Next]];
        }
    }
}
int dp[1029][11*102][2];
bool vis[1029][11*102][2];
int n,l,now,ans;
queue<int> quezt;
queue<int> quenow;
queue<int> quelen;
void bfs (int zt,int now0,int len)
{
    //printf("%d %d %d %d\n",zt,now0,len,dp[zt][now0][len%2]);
    //printf("%d\n",quezt.size());
    if(len==l) ans=max(ans,dp[zt][now0][l%2]);
    if(len==l+1) return;
    for(int i=0; i<4; i++)
    {
        int now=now0,temp=0;
        while(now&&!ch[now][i]) now = fail[now];
        now = ch[now][i];
        int newzt = zt;
        if(End[now])
        {
            if(((1<<(End[now]-1))|newzt)!=newzt) temp+=w[End[now]];
            newzt = (1<<(End[now]-1))|newzt;
        }
        int tmp = now;
        while(last[tmp])
        {
            if(End[last[tmp]])
            {
                if(((1<<(End[last[tmp]]-1))|newzt)!=newzt) temp+=w[End[last[tmp]]];
                newzt = (1<<(End[last[tmp]]-1))|newzt;
            }
            tmp = last[tmp];
        }
        if(newzt!=zt) {
            //printf("%d\n",temp);
            if(!vis[newzt][now][(len+1)%2]) dp[newzt][now][(len+1)%2]=dp[zt][now0][len%2]+temp;
            else dp[newzt][now][(len+1)%2]=max(dp[zt][now0][len%2]+temp,dp[newzt][now][(len+1)%2]);
        }
        else{
            if(!vis[zt][now][(len+1)%2]) dp[zt][now][(len+1)%2]=dp[zt][now0][len%2];
            else dp[zt][now][(len+1)%2]=max(dp[zt][now0][len%2],dp[zt][now][(len+1)%2]);
        }
        //dfs(newzt,now,len+1);
        if(!vis[newzt][now][(len+1)%2]) {
            quezt.push(newzt);
            quenow.push(now);
            quelen.push(len+1);
            vis[newzt][now][(len+1)%2]=true;
        }
    }
    //if(len==l) ans=max(ans,dp[zt][now0][l%2]);
}
int main()
{
    while(scanf("%d%d",&n,&l)!=EOF)
    {
        memset(dp,-1,sizeof(dp));
        memset(ch,0,sizeof(ch));
        memset(End,0,sizeof(End));
        memset(last,0,sizeof(last));
        cur = 1;
        int len;
        for(int i=1; i<=n; i++)
        {
            scanf("%s%d",str0[i],&w[i]);
            //puts(str0[i]);
            len = strlen(str0[i]);
            now = 0;
            for(int j=0; j<len; j++)
            {
                if(str0[i][j]=='A') str0[i][j]=0;
                if(str0[i][j]=='T') str0[i][j]=1;
                if(str0[i][j]=='G') str0[i][j]=2;
                if(str0[i][j]=='C') str0[i][j]=3;
                if(ch[now][str0[i][j]]==0) ch[now][str0[i][j]] = cur++;
                now = ch[now][str0[i][j]];
                if(str0[i][j]==0) str0[i][j]='A';
                if(str0[i][j]==1) str0[i][j]='T';
                if(str0[i][j]==2) str0[i][j]='G';
                if(str0[i][j]==3) str0[i][j]='C';
            }
            End[now] = i;
        }
        //printf("%d\n",cur);
        get_fail();
        //printf("%d\n",cur);
        dp[0][0][0]=0;
        quezt.push(0);
        quenow.push(0);
        quelen.push(0);
        memset(vis,false,sizeof(vis));
        vis[0][0][0]=true;
        ans=-1;
        int pre=0;
        while(!quezt.empty()) {
            //if(quelen.front()!=pre) {
            //    for(int i=0;i<1029;i++)
            //    for(int j=0;j<11*102;j++) dp[i][j][pre%2]=0;
            //    pre=quelen.front();
            //}
            bfs(quezt.front(),quenow.front(),quelen.front());
            vis[quezt.front()][quenow.front()][quelen.front()%2]=false;
            quezt.pop();quenow.pop();quelen.pop();
        }
        if(ans==-1) puts("No Rabbit after 2012!");
        else printf("%d\n",ans);
    }
}
posted @ 2016-03-18 18:24  ZhMZ  阅读(229)  评论(0编辑  收藏  举报