HDU 4899 Hero meet devil dp套dp

题目:http://acm.hdu.edu.cn/showproblem.php?pid=4899

题意:字符集中只有“ATCG”四种字符,现在给你一个字符串S。问长度为m的所有字符串T中,最长公共子序列lcs(S, T) = i的个数,其中 i 取遍0~|S|。|S|为S的长度。

 

如果我们直接用state表示lcs状态进行状压dp,那么会出现许多重复值

为了避免重复,我们需要一个newstate,表示最小的满足lcs=i的状态

那么如何表示最小呢,在state的基础上新加一个字符,看看是否能改变lcs的值,如果改变了,那么改变后的新状态就是一个最小的状态

例如只在state的最后一位新匹配了一个字符

state  010010 -> pre  011122 -> lcs  011123 ->newstate  010011

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<cmath>
#include<algorithm>
#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
using namespace std;
const int N=20;
const int mod=1e9+7;
char s[N],ch[5]="ACGT";
int lcs[N],pre[N];
int newstate[1<<15][4];
int dp[2][1<<15];
int ans[N];
int n,m,maxx;
inline void add(int &x,int y)
{
    x+=y;
    if (x>=mod) x-=mod;
}
void init()
{
    for(int state=0;state<maxx;state++)
    {
        pre[0]=0;
        for(int i=1;i<=n;i++) pre[i]=pre[i-1]+((state>>(i-1))&1);
        for(int k=0;k<4;k++)
        {
            for(int i=1;i<=n;i++)
            {
                if (s[i]==ch[k]) lcs[i]=pre[i-1]+1;
                else lcs[i]=max(pre[i],lcs[i-1]);
            }
            newstate[state][k]=0;
            for(int i=1;i<=n;i++)
                newstate[state][k]|=((lcs[i]!=lcs[i-1])<<(i-1));
        }
    }
}
int f(int x)
{
    int s=0;
    while(x)
    {
        s+=(x&1);
        x>>=1;
    }
    return s;
}
void solve()
{
    int now=0,last=1;
    memset(dp[now],0,sizeof(dp[now]));
    dp[now][0]=1;
    for(int i=0;i<m;i++)
    {
        swap(now,last);
        memset(dp[now],0,sizeof(dp[now]));
        for(int state=0;state<maxx;state++)
            if (dp[last][state])
            for(int k=0;k<4;k++)
                add(dp[now][newstate[state][k]],dp[last][state]);
    }
    memset(ans,0,sizeof(ans));
    for(int state=0;state<maxx;state++)
        add(ans[f(state)],dp[now][state]);
    for(int i=0;i<=n;i++)
        printf("%d\n",ans[i]);
}
int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%s%d",s+1,&m);
        n=strlen(s+1);
        maxx=1<<n;
        init();
        solve();
    }
    return 0;
}

  

posted @ 2017-09-01 11:25  BK_201  阅读(120)  评论(0编辑  收藏  举报