牛客多校第10场J Wood Processing 分治优化/斜率优化 DP

题意:你有n块木头,每块木头有一个高h和宽w,你可以把高度相同的木头合并成一块木头。你可以选择一些木头消去它们的一部分,浪费的部分是 消去部分的高度 * 木头的宽度,问把n块木头变成恰好m块木头至少要浪费多少木料?

思路:把木头从高到第排序,设dp[i][j]为前i块木头合并成了j块木头的最小花费。因为从大到小排序,所以合并后最后一块木头的高度一定是合并前的第i块木头的高度。那么,容易得出dp转移方程:dp[i][j] = min(dp[k][j - 1] + cal(k, i)),其中cal(k, i)为把第k + 1块木头到第i块木头的高度变成一样的花费。直接转移O(n * n * m),需要优化。

1:分治优化:设op[i][j]为向dp[i][j]转移的状态中最优值中最小的k,若op[i][j] <= op[i + 1][j], 那么便可以进行分治优化dp。对于此题,dp[x][j] + cal(x, i)和dp[y][j] + cal(y, j)(x < y)cal(x, i)和cal(y, i)有重合部分,所以有op[i][j] <= op[i + 1][j], 通过分治的过程可以缩小转移的范围,复杂度O(n * logn * m)。

代码:

#include <bits/stdc++.h>
#define LL long long
#define pll pair<LL, LL>
#define INF 1e18
using namespace std;
const int maxn = 5010;
const int maxm = 2010;
pll a[maxn];
int n, m;
LL f[maxm][maxn], w[maxn], sum[maxn];
LL cal(LL l, LL r, LL h) {
	return sum[r] - sum[l] - h * (w[r] - w[l]);
}
void solve(int x, int l, int r, int opl, int opr) {
	if(l > r) return;
	int mid = (l + r) >> 1;
	pll ans = make_pair(INF, INF);
	for (int i = opl; i < mid && i <= opr; i++) {
		ans = min(ans, make_pair(f[x - 1][i] + cal(i, mid, a[mid].first), (LL)i));
	}
	f[x][mid] = ans.first;
	LL opt = ans.second;
	solve(x, l, mid - 1, opl, opt);
	solve(x, mid + 1, r, opt, opr);
}
int main() {
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++) {
		scanf("%d%d", &a[i].second, &a[i].first);
	}
	sort(a + 1, a + 1 + n);
	reverse(a + 1, a + 1 + n);
	for (int i = 1; i <= n; i++) {
		w[i] = w[i - 1] + a[i].second;
		sum[i] = sum[i - 1] + a[i].second * a[i].first;
	}
	for (int i = 1; i <= n; i++)
		f[1][i] = cal(0, i, a[i].first);
	for (int i = 2; i <= m; i++) {
		solve(i, 1, n, 0, n);
	}
	printf("%lld\n", f[m][n]);
} 

思路2:斜率优化,把cal(k, i)式子列出来,用单调队列维护下凸包。场上没注意到斜率乘积会爆long long,非常可惜QAQ

#include <bits/stdc++.h>
#define LL long long
#define pll pair<LL, LL>
using namespace std;
const int maxn = 5010;
const int maxm = 2010;
pll a[maxn];
int n, m;
LL f[maxn][maxn], w[maxn], sum[maxn];
int q[maxn][maxm], l[maxm], r[maxm];
LL cal(LL x, LL y) {
    return f[x][y] - sum[x];
}
void update(int x, int y) {
    LL h = -a[x].first;
    while(l[y] < r[y]) {
        int p1 = q[y][l[y]], p2 = q[y][l[y] + 1];
        __int128 t = (__int128)cal(p2, y) - cal(p1, y);
        __int128 t1 = (__int128)h * (w[p2] - w[p1]);
        if(t <= t1) {
            l[y]++;
            continue;
        } else {
            break;
        }
    }
    int k = q[y][l[y]];
    f[x][y + 1] = f[k][y] + sum[x] - sum[k] + h * (w[x] - w[k]);
    while(l[y] < r[y]) {
        int p1 = q[y][r[y] - 1], p2 = q[y][r[y]];
        __int128 t = (__int128)(cal(p2, y) - cal(p1, y)) * (w[x] - w[p2]);
        __int128 t1 = (__int128)(cal(x, y) - cal(p2, y)) * (w[p2] - w[p1]);
        if(t >= t1) {
            r[y]--;
            continue;
        } else {
            break;
        }
    }
    q[y][++r[y]] = x;
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%d%d", &a[i].second, &a[i].first);
    }
    sort(a + 1, a + 1 + n);
    reverse(a + 1, a + 1 + n);
    for (int i = 1; i <= m; i++) {
        l[i] = 1, r[i] = 1;
        q[i][1] = 0;
    }
    for (int i = 1; i <= n; i++) {
        w[i] = w[i - 1] + a[i].second;
        sum[i] = sum[i - 1] + a[i].second * a[i].first;
    }
    for (int i = 1; i <= n; i++) {
        f[i][1] = sum[i] - a[i].first * w[i];
        for (int j = 2; j <= m; j++) {
            update(i, j - 1);
        }
    }
    printf("%lld\n", f[n][m]);
}

  

posted @ 2019-08-22 16:53  维和战艇机  阅读(374)  评论(0编辑  收藏  举报