洛谷 P1880 石子合并
题目
题目链接:P1880 [NOI1995] 石子合并
一道区间 DP 的典型题目。
区间 DP
特点
- 合并:将两个或多个部分进行整合,当然也可以反过来。
- 特征:能将问题分解为能两两合并的形式。
- 求解:对整个问题设最优值,枚举合并点,将问题分解为左右两个部分,最后合并两个部分的最优值得到原问题的最优值。
状态转移
下面,我们先考虑不在环上,而在一条链上的情况。
令状态 \(f(i,j)\) 表示将下标在 \([i,j]\) 区间的元素合并起来所能获得的最大价值,则 \(f(1,n)\) 就是问题的答案。状态转移式为:
\[f(i,j)=\max\{f(i,k)+f(k+1,j)+cost(i,j,k)\},\quad k\in[i,j)
\]
\(cost(i,j,k)\) 表示将区间 \([i,k]\) 和 \([k+1,j]\) 合并为 \([i,j]\) 的代价,这里的 \(k\) 就是要枚举的合并点。
递推求解
使用递推法求解区间 DP 时,通常的做法是从小到大枚举区间长度。这样能保证在求解大区间时,小区间的答案已经被求解出来了。
算法模板如下,时间复杂度 \(O(n^3)\)。
for (int len = 2; len <= n; len++)
{
for (int i = 1; i <= n; i++)
{
int j = i + len - 1;
for (int k = i; k < j && k <= n; k++)
{
f[i][j] = max(f[i][j], f[i][k] + f[k + 1][j] + cost(i, j, k));
}
}
}
环上的区间 DP
现在让我们回到原问题,怎么处理在环上的情况呢?
如果是在一个长为 \(n\) 的环上,那么弄一条长为 \(2n\) 的链(重复一次),DP 求解后取 \(f(1,n),f(2,n+1),\ldots,f(n-1,2n-1)\) 中的最优值即可。时间复杂度仍为 \(O(n^3)\)。
参考资料:区间 DP - OI Wiki
代码
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int INF = int(1e9);
const int maxn = 100 + 5;
int dp_min[maxn * 2][maxn * 2];
int dp_max[maxn * 2][maxn * 2];
int arr[maxn * 2];
int prefix_sum[maxn * 2];
//由于我们要把链重复一次,所以数组要开两倍大小
//预处理出前缀和
void calc_prefix_sum(int n)
{
prefix_sum[0] = 0;
for (int i = 1; i <= 2 * n - 1; i++)
prefix_sum[i] = prefix_sum[i - 1] + arr[i];
}
//[x, y] 的区间和
int sum(int x, int y)
{
return prefix_sum[y] - prefix_sum[x - 1];
}
int main()
{
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
{
int a;
scanf("%d", &a);
arr[i] = arr[n + i] = a;
}
calc_prefix_sum(n);
//递推求解
for (int len = 2; len <= n; len++)
{
for (int i = 1; i <= 2 * n - 1; i++)
{
int j = i + len - 1;
int temp_min = INF;
int temp_max = 0;
//枚举合并点 k,计算将 [i, k] 和 [k+1, j] 合并所需的花费
for (int k = i; k < j && k <= 2 * n - 1; k++)
{
temp_min = min(temp_min, dp_min[i][k] + dp_min[k + 1][j] + sum(i, j));
temp_max = max(temp_max, dp_max[i][k] + dp_max[k + 1][j] + sum(i, j));
}
dp_min[i][j] = temp_min;
dp_max[i][j] = temp_max;
}
}
int ans_min = dp_min[1][n];
int ans_max = dp_max[1][n];
for (int i = 2; i <= n - 1; i++)
{
ans_min = min(ans_min, dp_min[i][i + n - 1]);
ans_max = max(ans_max, dp_max[i][i + n - 1]);
}
printf("%d\n%d", ans_min, ans_max);
return 0;
}