一个关于序列的游戏——DP综合题
题目
有一个序列,你可以在上面删除符合要求的连续段若干次。每次删除都会得到连续段长度对应的分数。
需要符合的要求为:
1、相邻两个元素相差为1
2、如果某个元素不在连续段的最左或最右,那么这个元素就不能同时小于相邻的左右两个元素。
“1、2、3、4、3” “1、2” “3、2” “3”都符合条件。
显然,删除掉连续段后,这个段的左边和右边并在一起成为相邻元素。
你的任务是对于给出的序列,计算出可能获得的最大总分。
题解
首先肯定是一道区间DP,不过如果我们直接维护最大收益的话,感觉在转移时非常困难。
仔细思考可以发现,最大收益应该是由几段被完全删去的区间所拼接而成,这个拼接的过程可以转化为\(O(n^3)\)的区间DP(就是个板子),考虑求解完全删去一段的最大收益
设:\(f[l,r]\)表示完全删去区间\([l,r]\)的最大收益,\(ans[l,r]\)为区间\([l,r]\)的最大收益(最终答案)
关于\(f[l,r]\)的转移,可以分成两种情况,一是\(l,r\)最后被一次取掉,二是\([l,k],[k+1,r]\)分别取,下面讨论第一种如何转移
其实本质上,一个连续段就是由两个公差为1的等差数列组成的,并且具备先递增后递减的单峰性质,而区间DP在内层本就是在枚举断点\(k\),可以让\([l,k]\)为递增的一段,\([k,r]\)为递减的一段。所以我们可以处理出递增一段的最大收益与递减一段的最大收益来处理,设递增的收益为\(up[l,r]\),递减的为\(down[l,r]\)。这里因为公差为1,可以利用这个性质,强制命\(up[l,r]\)与\(down[l,r]\)为不能删去\(l,r\)两个位置的最大收益,有了这个,我们就可以快速计算出\([l,r]\)断点为\(k\)时,这个序列长度为\(a[k]\times 2-a[l]-a[r]+1\),进而求出解。
那么转移就可以写成:
int get(int l,int r,int k){
if(a[k]+a[k]-a[l]-a[r]+1<0)return -inf;
if(a[k]<a[l]||a[k]<a[r])return -inf;
return up[l,k]+down[k,r]+val[a[k]+a[k]-a[l]-a[r]+1];
}
//在solve中
for(int len=2;len<=n;len++){
for(int l=1,r=len;r<=n;l++,r++){
for(int k=l;k<=r;k++){
f[l][r]=max(f[l][r],get(l,r,k));
if(k<r)f[l][r]=max(f[l][r],f[l][k]+f[k+1][r]);
}
}
}
那么问题就变为了如何维护\(up,down\),注意\(up,down\)内部被删除的几个区间并不一定是连续的。考虑如何维护。这里又可以运用上一个性质:相同的\(a_i\)最多出现7次,此时我们可以考虑用新的\(f[l,r]\)来更新以后的\(up,down\).那么\(f[l,r]\)能更新\(up\)的充要条件就在于\(a[l-1]=a[r+1]-1\),\(down\)同理,此时我们就可以枚举包含\(l,r\)的区间,假设为\(L,R\),那么可以用\(up[L,l-1]+f[l,r]+up[r+1,R]\)来更新\(up[L,R]\);同理当\(a[l-1]-1=a[r+1]\)成立时,用\(down[L,l-1]+f[l,r]+down[r+1,R]\)更新\(down[L,R]\).由于相同\(a_i\)最多出现7次,那么这个最多让常数大几倍
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 250
#define inf 0x3f3f3f3f
int f[N][N], up[N][N], down[N][N], ans[N][N], n, m, a[N], val[N];
int get(int l, int r, int k) {
if (a[k] + a[k] - a[l] - a[r] + 1 > n)return -inf;
if (a[k] < a[l] || a[k] < a[r])return -inf;
return up[l][k] + down[k][r] + val[a[k] + a[k] - a[l] - a[r] + 1];
}
void init() {
cin >> n;
for (int i = 1; i <= n; i++)cin >> val[i];
for (int i = 1; i <= n; i++)cin >> a[i];
}
//在solve中
void solve() {
memset(f, 0xcf, sizeof f);
memset(up, 0xcf, sizeof f);
memset(down, 0xcf, sizeof f);
for (int i = 1; i <= n; i++)f[i][i - 1] = 0;
f[n + 1][n] = 0;
for (int i = 1; i <= n; i++)up[i][i] = down[i][i] = 0;
for (int len = 0; len <= n; len++) {
for (int l = 1, r = len; r <= n; l++, r++) {
for (int k = l; k <= r; k++) {
f[l][r] = max(f[l][r], get(l, r, k));
if (k < r)f[l][r] = max(f[l][r], f[l][k] + f[k + 1][r]);
}
if (a[l - 1] == a[r + 1] + 1) {
for (int L = 1; L < l; L++) {
for (int R = r + 1; R <= n; R++) {
down[L][R] = max(down[L][R], down[L][l - 1] + f[l][r] + down[r + 1][R]);
}
}
}
if (a[l - 1] == a[r + 1] - 1) {
for (int L = 1; L < l; L++) {
for (int R = r + 1; R <= n; R++) {
up[L][R] = max(up[L][R], up[L][l - 1] + f[l][r] + up[r + 1][R]);
}
}
}
// printf("%d ", f[l][r]);
}
// puts("");
}
for (int len = 1; len <= n; len++) {
for (int l = 1, r = len; r <= n; l++, r++) {
ans[l][r] = max(ans[l][r], f[l][r]);
for (int k = l; k < r; k++) {
ans[l][r] = max(ans[l][r], ans[l][k] + ans[k + 1][r]);
}
}
}
}
int main() {
init();
solve();
printf("%d\n", ans[1][n]);
return 0;
}
说明:由于我们是用现成的\(f[i,j]\)更新以后的\(up,down\),所以必须从长度为0的区间开始更新(表示这里不删),故将长度为0的区间的\(f\)值变成0,然后将长度为1的\(down,up\)值设为0(这个是为了更新长度为1的\(f\)).