bzoj3864 Hero meet devil(dp套dp)
题面
题目大意:
给出一个模式串\(S(|S|≤15)\) 问存在多少个长为\(m(m≤1000)\) 的字符串T满足\(LCS(S,T)=x(0≤x≤|S|)\) 输出\(|S|+1\)个结果\((mod 1e9+7)\) (\(|S|\)表示字符串S的长度,字符集为\(A,T,C,G\)四个字母)
题解
朴素\(lcs\)的\(dp\)
for (int i = 1; i <= n; i++)
for (int j = 1; j <= m; j++)
if (a[i] == b[j]) f[i][j] = f[i-1][j-1]+1;
else f[i][j] = max(f[i-1][j], f[i][j-1], f[i-1][j-1]);
我们能发现
-
\(f[i][j]\)和\(f[i][j-1]\),\(f[i][j+1]\)最多相差\(1\)
-
\(|S| ≤ 15\)
我们可以把\(j\)那一维的差分数组状压一下
然后呢?
设\(f[i][S]\)表示在第\(i\)个位置,此时\(lcs\)的状态为\(S\)的方案数
预处理出 \(nxt[S][A/C/G/T]\) 为\(S\)状态下,添加\(A/C/G/T\)后分别的状态
然后就有
\(f[i+1][nxt[s][k]] += f[i][s]\)
至于预处理,我们把状压还原出来
模拟朴素\(dp\)一遍,再压回去
Code
#include<bits/stdc++.h>
#define LL long long
#define RG register
using namespace std;
template<class T> inline void read(T &x) {
x = 0; RG char c = getchar(); bool f = 0;
while (c != '-' && (c < '0' || c > '9')) c = getchar(); if (c == '-') c = getchar(), f = 1;
while (c >= '0' && c <= '9') x = x*10+c-48, c = getchar();
x = f ? -x : x;
return ;
}
template<class T> inline void write(T x) {
if (!x) {putchar(48);return ;}
if (x < 0) x = -x, putchar('-');
int len = -1, z[20]; while (x > 0) z[++len] = x%10, x /= 10;
for (RG int i = len; i >= 0; i--) putchar(z[i]+48);return ;
}
const int N = 1001, Mod = 1e9+7;
char S[16], SS[5] = {"ACGT"};
int a[16], f[N][(1<<15)+1], nxt[(1<<15)+1][5], n, len, limit, ans[16];
int tmp[2][16];
int solve(int s, int ch) {
int ret = 0;
memset(tmp, 0, sizeof(tmp));
for (int i = 0; i < n; i++) tmp[0][i+1] = tmp[0][i]+((s>>i)&1);
for (int i = 1; i <= n; i++) {
int mx = 0;
if (a[i] == ch) mx = tmp[0][i-1]+1;
mx = max(max(mx, tmp[0][i]), tmp[1][i-1]);
tmp[1][i] = mx;
}
for (int i = 0; i < n; i++) ret += (1<<i)*(tmp[1][i+1]-tmp[1][i]);
return ret;
}
int main() {
//freopen(".in", "r", stdin);
//freopen(".out", "w", stdout);
int q; read(q);
while (q--) {
memset(f, 0, sizeof(f)); memset(ans, 0, sizeof(ans));
scanf("%s", S+1);
n = strlen(S+1); limit = 1<<n;
for (int i = 1; i <= n; i++)
for (int j = 0; j < 4; j++)
if (S[i] == SS[j]) {a[i] = j+1; break;}
read(len);
for (int s = 0; s < limit; s++)
for (int j = 1; j <= 4; j++)
nxt[s][j] = solve(s, j);
f[0][0] = 1;
for (int i = 0; i < len; i++)
for (int s = 0; s < limit; s++)
for (int k = 1; k <= 4; k++)
(f[i+1][nxt[s][k]] += f[i][s]) %= Mod;
for (int s = 0; s < limit; s++) {
int cnt = __builtin_popcount(s);
(ans[cnt] += f[len][s]) %= Mod;
}
for (int i = 0; i <= n; i++)
printf("%d\n", ans[i]);
}
return 0;
}