[BZOJ3864] Hero meet devil
Desprition
给出一个只由字母 \(A\),\(G\),\(T\),\(C\) 组成的字符串 \(S\) ,长度为 \(n\) ,对于每个 \(i\) \(\in\) \([0,n]\),问有多少个长度为 \(m\),仅含有 \(A\),\(G\),\(T\),\(C\) 的字符串 \(T\) 使得 \(S\) 与 \(T\) 的最长公共子序列长度为 \(i\) 。
Solution
先研究 \(LCS\) 的转移柿子
得到结论:
当 \(i\) 固定时,
也就是,\(lcs[i][j]\) 和 \(lcs[i][j - 1]\) 最多相差 \(1\) 且满足单调不减。
因此我们可以使用差分,又因为 \(\left|\ S \right| <= 15\),可以直接把差分后的\(lcs\) 状压起来。
定义 \(f[i][j]\): 当 \(lcs\) 状态为 \(i\) 时加上 字符 \(j\) 后的状态, \(dp[i][j]\): 长度为 \(i\) 时,状态为 \(j\) 的方案数。
首先预处理出 \(f[i][j]\),具体注释在代码里。
重点在于 \(dp\) 柿子的推导, 其实很简单啊。
你想嘛, \(dp[i][j]\) 的定义长度为 \(i\) 时,状态为 \(j\) 的方案数, 那肯定是从 \(dp[i - 1][k]\) 转移过来的。
那 \(k\) 怎么确定呢???
前面的 \(f[][]\) 不就是用来干这件事的吗?
直接枚举当前长度 \(i\), 长度为 \(i - 1\) 时的状态 \(j\) 以及 第 \(i\) 位的情况 \(k\)
那么当前状态就应该是 \(f[j][k]\) —— 在 \(j\) 后增加 字符\(k\) 的状态。
那么就可以得出式子啦~
最后答案的计算,因为将 \(lcs\) 差分了,所以 \(lcs\) 的实际长度就是 状压状态下 \(i\) 中 \(1\) 的个数。
另外的小细节在代码中啦~
Code
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define il inline
const int N = (1 << 15) + 5, M = 1e3 + 5, mod = 1e9 + 7;
int T, n, m, num, t, a[20], f[N][5], dp[M][N], g[2][20], ans[20];
char s[N];
il int cnt(int x) { // 1 的个数
int res = 0;
while(x) {
res += x & 1;
x >>= 1;
}
return res;
}
il int solve(int x,int y) {
int res = 0;
memset(g,0,sizeof(g));//g[0/1] 加入字符前后的差分lcs
for(int i = 0; i < n; i ++) g[0][i + 1] = g[0][i] + ((x >> i) & 1); // 用 g 表示将 lcs 差分
for(int i = 1; i <= n; i ++) {
if(a[i] == y) g[1][i] = g[0][i - 1] + 1;//当前位相同的话, 前 i-1 位 + 1 加上当前位的贡献
g[1][i] = max(max(g[1][i],g[0][i]),g[1][i - 1]); //lcs通用求法
}
for(int i = 0; i < n; i ++) res += (1 << i) * (g[1][i + 1] - g[1][i]); // 状压
return res;
}
il void read(int &x) {
x = 0; char s = getchar();
while(s < '0' || s > '9') s = getchar();
while(s <= '9' && s >= '0') x = x * 10 + s - '0', s = getchar();
}
il void write(int x) {
if(x < 0) x = -x, putchar('-');
if(x > 9) write(x / 10), x %= 10;
putchar(x + '0');
}
int main() {
read(T);
while(T--) {
memset(ans,0,sizeof(ans));
memset(dp,0,sizeof(dp));
scanf("%s",s + 1), read(m);
n = strlen(s + 1), num = 1 << n;
for(int i = 1; i <= n; i ++) { //处理字符
if(s[i] == 'A') a[i] = 1;
else if(s[i] == 'G') a[i] = 2;
else if(s[i] == 'T') a[i] = 3;
else a[i] = 4;
}
for(int i = 0; i < num; i ++) {
for(int j = 1; j <= 4; j ++) f[i][j] = solve(i,j); // 枚举每种状态加上不同字符的情况
}
dp[0][0] = 1;//边界
for(int i = 1; i <= m; i ++) {
for(int j = 0; j < num; j ++) {
for(int k = 1; k <= 4; k ++) dp[i][f[j][k]] = (dp[i][f[j][k]] + dp[i - 1][j]) % mod; // 前面有讲解
}
}
for(int i = 0; i < num; i ++) {
t = cnt(i);
ans[t] = (ans[t] + dp[m][i]) % mod; // 累计答案
}
for(int i = 0; i <= n; i ++) write(ans[i]), putchar('\n');
}
return 0;
}