NC19996 [HAOI2015]树上染色

题目链接

题目

题目描述

有一棵点数为N的树,树边有边权。给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑色,并将其他的N-K个点染成白色。

将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。问收益最大值是多少。

输入描述

第一行两个整数N,K。
接下来N-1行每行三个正整数fr,to,dis,表示该树中存在一条长度为dis的边(fr,to)。
输入保证所有点之间是联通的。N ≤ 2000,0 ≤ K ≤ N

输出描述

输出一个正整数,表示收益的最大值。

示例1

输入

5 2
1 2 3
1 5 1
2 3 1
2 4 2

输出

17

说明

【样例解释】
将点1,2染黑就能获得最大收益。

备注

对于100% 的数据,\(0 \leq n,k \leq 2000\)

题解

知识点:树形dp,背包dp。

这道题长见识了,虽然显然是树上背包,但状态很妙。

\(dp[u][i]\) 为以 \(u\) 为根的子树选了 \(i\) 个点染黑后对答案的最大总贡献,注意这里是子树对答案的贡献,而非子树内的贡献。如果只是子树内的贡献,会发现转移时不知道点的具体位置,从而不能计算权值的变化,无法转移;而计算子树对整个答案的贡献就不需要考虑内点位置,而只要考虑父节点与子树根节点连的那条边的权值与子树染黑节点的数量即可转移。转移方程为:

\[dp[u][i] = \max(dp[u][i], dp[u][i - j] + dp[v][j] + val) \]

其中, \(val = j(m - j) \cdot w + (sz[v] - j)(n - m - (sz[v] - j)) \cdot w\) ,这里我把题目中的 \(K\) 改为了 \(m\) 方便使用。

首先转移方程表示为 \(u\) 为根的子树选了 \(i-j\) 个点染黑,\(v\) 为根的子树选了 \(j\) 个点染黑,并连接 \((u,v)\) 这条边产生 \(val\) 的总贡献是否更大。前面两个没什么问题,考虑 \(val\) 如何计算。

\((u,v)\) 的权为 \(w\) ,则 \((u,v)\) 两边的黑节点和白节点的每个组合都能多一次 \(w\) 的贡献。因此黑节点的组合贡献多了 \(j(m-j) \cdot w\) ,因为一端是子树 \(j\) 个黑节点,另一端是其他 \(m-j\) 个黑节点;白节点的组合贡献多了 \((sz[v] - j)(n - m - (sz[v] - j)) \cdot w\) ,因为子树有 \(sz[v]-j\) 个白节点,其他有 \(n-m-(sz[v]-j)\) 个白节点。最后加起来就是 \(val\)

到这里整道题算是做完了,但细节上有很多值得注意的。比如 \(j=0\) 时,会发现转移方程变为:

\[dp[u][i] = \max(dp[u][i], dp[u][i] + dp[v][0] + val) \]

直接原地更新了,这意味着更新 \(j=0\) 时, \(dp[u][i]\) 必须是原来的,这导致 \(j=0\) 必须第一个更新,才能更新别的,正序更新是直接满足这个要求。

除此之外还需要对dp范围进行剪枝,不然铁定超时。这里推荐用刷表法,因为最佳循环条件十分容易就能写出来(即没有浪费一点时间在没用的状态上),打表法不是不能写但非常麻烦,如下面我给出代码就是用打表法写的,虽然这道题用不着这么严格。

剪枝后的复杂度可以证明是 \(O(nm)\)

时间复杂度 \(O(nm)\)

空间复杂度 \(O(nm)\)

代码

#include <bits/stdc++.h>
#define ll long long

using namespace std;

int n, m;
vector<pair<int, int>> g[2007];
int sz[2007];
ll dp[2007][2007];

void dfs(int u, int fa) {
    sz[u] = 1;
    //dp[u][0] = dp[u][1] = 0;
    for (auto [v, w] : g[u]) {
        if (v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
        for (int i = min(sz[u], m);i >= 0;i--) {
            for (int j = max(sz[v] - sz[u] + i, 0);j <= min(i, sz[v]);j++) {
                ///严格区间可以省掉许多时间(这道题是几百倍),但一般不敢这么严格,通常负无穷区间即可。
                ///不过这道题用刷表法后,严格区间会很好得到
                ///这道题需要加两个min限制一下不然会超时,但j的起点max是不必要的。
                ll val = 1LL * j * (m - j) * w + 1LL * (sz[v] - j) * (n - m - (sz[v] - j)) * w;
                dp[u][i] = max(dp[u][i], dp[u][i - j] + dp[v][j] + val);
                ///注意j=0时,dp[i][j]会被自己更新,如果j是倒序,会导致之前修改的重复作用,因此j=0必须第一个修改好,其他的顺序随意
            }
        }
    }
}

int main() {
    std::ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m;
    for (int i = 1;i < n;i++) {
        int u, v, w;
        cin >> u >> v >> w;
        g[u].push_back({ v,w });
        g[v].push_back({ u,w });
    }
    //memset(dp, -0x3f, sizeof(dp));
    dfs(1, 0);
    cout << dp[1][m] << '\n';
    return 0;
}
posted @ 2022-08-24 14:23  空白菌  阅读(69)  评论(0编辑  收藏  举报