P3592 [POI2015] MYJ (区间dp)
upd 2024.11.13
区间 dp
这题的思路没有特别明显。可以从答案的计算出发,既然要消费,不妨设 \([1,n]\) 的最小值位置在 \(p\),那么在这里消费的顾客是可知的。
这时候剩下有消费的顾客一定在 \([1,p-1]\) 和 \([p+1,n]\)!此时就可以看出这其实就是区间 dp 的基本形式。
对于这类与区间最值有关的问题,也可以通过笛卡尔树的结构解题,这题里可以用它理解区间 dp 的过程。
状态好想,但是还需要优化值域的枚举。
值域的枚举有什么优化的地方?这时候就需要注意到枚举是否出现冗杂,这题是可以发现价格一定为出现过的 \(c_i\)。只需要枚举这些数就可以了啊!
要求总和最大,有两张思路:贪心和 dp。稍微想一下,发现贪心思考量太大,考虑 dp
观察 n 的数据范围,以及转移方式,可以想到区间 dp
发现转移跟区间最小值有关,设 \(f_{l,r,k}\) 为区间 \([l,r]\) 中最小值不小于 \(x\) 的答案。
转移枚举最小值的位置 \(p\),\(f_{l,r,k}=\max(f_{l,p-1,k}+f_{p+1,r,k}+cost(l,r,p,i))\)
\(cost(l,r,p,i)\) 表示最小值在 \(p\) 处时的贡献,即 \(l\le a_i\le p\le b_i\le r\) 且 \(c_i\ge k\) 的数量。转移时预处理 \(buc_{l,r}\) 表示 \([l,r]\) 中的区间数量即可。
还有转移 \(f_{l,r,k}=f_{l,r,k+1}\),容易理解。
现在的复杂度为 \(O(n^3\max(c_i))\),无法通过。发现对于每个位置,我们都只会选择在 \(c_i\) 中出现过的,否则一定不优。所以离散化 \(c_i\),复杂度降到 \(O(n^3m)\),可以通过。
答案为 \(f_{1,n,1}\)。
考虑打印方案,只需要在转移时记录断点以及断点处的值,dfs 一遍即可。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define fi first
#define se second
typedef long long i64;
int n, m;
int a[4010], b[4010], c[4010], d[4010];
int f[60][60][4010], buc[60][60], ans[60];
pii pa[60][60][4010];
void dfs(int l, int r, int p) {
if(!pa[l][r][p].fi) return;
pii now = pa[l][r][p];
ans[now.fi] = d[now.se];
dfs(l, now.fi - 1, now.se), dfs(now.fi + 1, r, now.se);
}
void Solve() {
std::cin >> n >> m;
for(int i = 1; i <= m; i++) {
std::cin >> a[i] >> b[i] >> c[i];
d[i] = c[i];
}
std::sort(d + 1, d + m + 1);
for(int i = 1; i <= m; i++) c[i] = std::lower_bound(d + 1, d + m + 1, c[i]) - d;
for(int i = m; i >= 1; i--) {
for(int j = 1; j <= m; j++) {
if(c[j] == i) {
for(int l = 1; l <= a[j]; l++) {
for(int r = b[j]; r <= n; r++) {
buc[l][r]++;
}
}
}
}
for(int len = 1; len <= n; len++) {
for(int l = 1, r = len; r <= n; l++, r++) {
f[l][r][i] = f[l][r][i + 1], pa[l][r][i] = pa[l][r][i + 1];
for(int k = l; k <= r; k++) {
int cnt = buc[l][r] - buc[l][k - 1] - buc[k + 1][r];
int ret = f[l][k - 1][i] + f[k + 1][r][i] + cnt * d[i];
if(f[l][r][i] < ret) {
f[l][r][i] = ret;
pa[l][r][i] = {k, i};
}
}
if(!pa[l][r][i].fi) pa[l][r][i] = {r, i};
}
}
}
std::cout << f[1][n][1] << "\n";
dfs(1, n, 1);
for(int i = 1; i <= n; i++) std::cout << ans[i] << " \n"[i == n];
return;
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
Solve();
return 0;
}