[HAOI2016]字符合并
题意
link
有一个长度为 \(n(\leq 500)\) 的 01 串,你可以每次将相邻的 \(k(\leq 8)\) 个字符合并,得到一个新的字符并获得一定分数。
得到的新字符和分数由这 \(k\) 个字符确定。你需要求出你能获得的最大分数。
区间dp + 状态压缩
状态
这个数据范围一般都是区间dp,但是如果只设 \(f[l][r]\) 表示区间最大分数,这里区间难以合并,因为分开后的区间还能组合。
由于要求最大分数,并且分数都是正数,那么一个区间 \([l, r]\) 最后一定会剩下 <k 个数,如果记录这个状态就好合并了。
于是设 \(f[l][r][S]\) 表示区间 \([l, r]\) 最后变成 \(S\) 的最大分数。
初始状态
\(f[l][l][s[l]] = 0\) 。
其它非法状态设为 负无穷。
转移
考虑区间合并 设 \(len = (r - l) \mod (k - 1)\) 表示\([l, r - 1 - x(k - 1)](r - 1 - x(k - 1) >= l, x\in Z)\) 最后剩下的长度, 将最后长度为 \(len\) 的状态 和 为\(1\)的状态合并。
\(len = 0\) 时,其实是 \(k - 1\)。
1.区间合并。
\[f[l][r][S << 1 | 1] = \max_{p = r - 1 - x(k - 1)}\{f[l][p][S] + f[p][r - 1][1]\}
\]
\[f[l][r][S << 1 | 0] = \max_{p = r - 1 - x(k - 1)}\{f[l][p][S] + f[p][r - 1][0]\}
\]
- 当 \(len = k - 1\), 其实区间最终长度是 \(1\),那么还可以合并成新状态。
直接枚举所有情况。
\[f[l][r][0] = \max_{g[S] = 0}\{f[l][r][S] + val[S]\}
\]
\[f[l][r][1] = \max_{g[S] = 1}\{f[l][r][S] + val[S]\}
\]
分析
区间dp是 \(O(n^3)\), 枚举子集是 \(O(2^k)\)。
时间复杂度是 \(O(n^32^k)\)。
代码:
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const int MAXN = 310;
const int INF = 0x7fffffff;
const int mod = 1000000007;
const double eps = 1e-9;
template <typename T>
void Read(T &x) {
x = 0; T f = 1; char a = getchar();
for(; a < '0' || '9' < a; a = getchar()) if (a == '-') f = -f;
for(; '0' <= a && a <= '9'; a = getchar()) x = (x * 10) + (a ^ 48);
x *= f;
}
inline void add(int &a, const int &b) {
a = a + b;
if (a >= mod) a -= mod;
if (a < 0) a += mod;
}
inline int mul(const int &a, const int &b) {
return 1ll * a * b % mod;
}
int qpow(int a, int b) {
int sum(1);
while(b) {
if (b & 1) sum = mul(sum, a);
a = mul(a, a);
b >>= 1;
}
return sum;
}
int n, k;
int a[MAXN], g[MAXN], val[MAXN];
ll f[MAXN][MAXN][257];
int main() {
cin >> n >> k;
for (int i = 1; i <= n; i ++)
cin >> a[i];
for (int i = 0; i < (1 << k); i ++)
cin >> g[i] >> val[i];
memset(f, -0x3f, sizeof(f));
for (int l = n; l >= 1; l --) {
f[l][l][a[l]] = 0;
for (int r = l + 1; r <= n; r ++) {
int len = (r - l) % (k - 1);
if (!len) len = k - 1;
for (int m = r; m > l; m -= k - 1)
for (int i = 0; i < (1 << len); i ++) {
f[l][r][i << 1] = max(f[l][r][i << 1], f[l][m - 1][i] + f[m][r][0]);
f[l][r][i << 1 | 1] = max(f[l][r][i << 1 | 1], f[l][m - 1][i] + f[m][r][1]);
}
if (len == k - 1) {
ll tmp[2];
tmp[0] = tmp[1] = -INF;
for (int i = 0; i < (1 << k); i ++)
tmp[g[i]] = max(tmp[g[i]], f[l][r][i] + val[i]);
f[l][r][0] = tmp[0], f[l][r][1] = tmp[1];
}
}
}
ll ans = -INF;
for (int i = 0; i < (1 << k); i ++)
ans = max(ans, f[1][n][i]);
cout << ans;
return 0;
}