[POI2015] MYJ
题面描述:
有 \(n\) 家洗车店从左往右排成一排,每家店都有一个正整数价格 \(p_i\)。有 \(m\) 个人要来消费,第 \(i\) 个人会驶过第 \(a_i\) 个开始一直到第 \(b_i\) 个洗车店,且会选择这些店中最便宜的一个进行一次消费。但是如果这个最便宜的价格大于 \(c_i\),那么这个人就不洗车了。请给每家店指定一个价格,使得所有人花的钱的总和最大。
解题思路:
首先要注意到 \(n\) 只有 \(50\),然后发现 \(Ci\) 很大,但是 \(m\) 只有 \(40000\),所以可以离散化。
再思考,发现每个点的价格只会是所有覆盖它的区间的 \(Ci\) 中的一个。
并且,他全部是区间,所以我们可以考虑区间 \(dp\)。
设状态 \(f_{i,j,k}\) 表示区间 \([i,j]\),中的最小值为 \(= k\) 所获得的价值。
设状态 \(dp_{i,j,k}\) 表示区间 \([i,j]\),中的最小值为 \(\ge k\) 所获得的价值。
因为要求方案,所以记一个 \(g[i][j]\) 表示区间 \([i,j]\) 的断点在哪(等价于记录从哪里转移来的)。
然后我们区间 \(dp\) 枚举区间长度,然后枚举最小值(从大到小)和其位置。
那么转移很显然。
\(f_{i,j,k}=max(dp_{i,mid-1,k}+dp_{mid+1,j,k}+cnt_{mid,k}\times C_{k})\)
\(dp_{i,j,k}=max(dp_{i,j,k+1},f_{i,j,k})\)
上方 \(cnt\) 数组指的是在区间 \([i,j]\) 中断点在 \(mid\),最小值 \(\ge k\) 经过这个点的线段数量。
\(cnt\) 数组也是非常好处理啊,先直接扫一遍全部线段,然后把两端在 \([i,j]\) 内的线段抓出来,因为对于某一个线段中间的所有点作为断点时,这个线段都会被记录所以直接加一遍。然后因为我们记录的是最小值 \(\ge k\) 的线段数量,所以还要累加一下。
最后答案 $ \max _{1} ^{\text{max}\text{{c_i}}} f[i][j][k]$。
方案直接递归求解。
代码实现:
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = 55, M = 4e3 + 5, P = 5e5 + 5;
int n, m;
int tmp[M];
int cnt[N][M];//cnt[mid][K]表示区间[i,j]断点为mid最小值为k的cnt
int f[N][N][M], dp[N][N][M], g[N][N][M], col[N][N][M];
//f[i][j][k]表示区间[i,j]中最小值==tmp[k]的最大花费和
//dp[i][j][k]表示区间[i,j]中最小值>=tmp[k]的最大花费和
int a[M], b[M], c[M], ansid[N];
void lsh(){
for(int i = 1; i <= m; i++) tmp[i] = c[i];
sort(tmp + 1, tmp + 1 + m);
int len = unique(tmp + 1, tmp + 1 + m) - tmp - 1;
for(int i = 1; i <= m; i++)
c[i] = lower_bound(tmp + 1, tmp + 1 + len, c[i]) - tmp;
}
void dfs(int l, int r, int op){
if(l > r) return;
ansid[g[l][r][op]] = tmp[col[l][r][op]];
dfs(l, g[l][r][op] - 1, col[l][r][op]);
dfs(g[l][r][op] + 1, r, col[l][r][op]);
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n >> m;
for(int i = 1; i <= m; i++)
cin >> a[i] >> b[i] >> c[i];
lsh();
for(int len = 1; len <= n; len++){
for(int i = 1; i + len - 1 <= n; i++){
int j = i + len - 1;
memset(cnt, 0, sizeof cnt);
for(int k = 1; k <= m; k++)
if(a[k] >= i && b[k] <= j)
for(int l = a[k]; l <= b[k]; l++)
cnt[l][c[k]]++;
for(int l = i; l <= j; l++)
for(int k = m - 1; k >= 1; k--)
cnt[l][k] += cnt[l][k + 1];
for(int k = m; k >= 1; k--){
for(int l = i; l <= j; l++){
if(f[i][j][k] <= dp[i][l - 1][k] + dp[l + 1][j][k] + cnt[l][k] * tmp[k]){
f[i][j][k] = dp[i][l - 1][k] + dp[l + 1][j][k] + cnt[l][k] * tmp[k];
g[i][j][k] = l;
}
}
if(dp[i][j][k + 1] > f[i][j][k]){
dp[i][j][k] = dp[i][j][k + 1];
g[i][j][k] = g[i][j][k + 1];
col[i][j][k] = col[i][j][k + 1];
}else{
dp[i][j][k] = f[i][j][k];
col[i][j][k] = k;
}
}
}
}
int ans = -0x3f3f3f3f3f3f;
for(int i = 1; i <= m; i++)
ans = max(ans, dp[1][n][i]);
cout << ans << endl;
dfs(1, n, 1);
for(int i = 1; i <= n; i++) cout << ansid[i] << " ";
cout << endl;
return 0;
}