洛谷 P3177 [HAOI2015]树上染色

题目链接

题目描述

有一棵点数为 \(N\) 的树,树边有边权。给你一个在 \(0~ N\) 之内的正整数 \(K\) ,你要在这棵树中选择 \(K\)个点,将其染成黑色,并将其他 的\(N-K\)个点染成白色 。 将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间的距离的和的受益。问受益最大值是多少。

题解

有点难想的dp 我果然太菜了
%%%__stdcall

\(f[i][j]\) 为以\(i\)为根的子树, 选了染了\(j\)个黑点的最大贡献

然后就是树形背包。。

siz[u]为以u为根的子树大小


for (int j = Min(K, siz[u]); j >= 0; j--)
    for (int k = 0; k <= Min(j, siz[v]); k++)
	    if (f[u][j-k] >= 0) {
            long long val = 1ll*k*(K-k)*g[i].w + 1ll*(siz[v]-k)*(n-K+k-siz[v])*g[i].w;
            f[u][j] = Max(f[u][j], f[u][j-k] + f[v][k] + val);
	}

贡献为子树贡献加上该边的贡献(子树黑点个数 * 其它黑点个数 * 边权 + 子树白点个数 * 其它白点个数 * 边权 )

Code


#include<bits/stdc++.h>
#define LL long long
#define RG register
using namespace std;

inline int gi() {
    int f = 1, s = 0;
    char c = getchar();
    while (c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == '-') f = -1, c = getchar();
    while (c >= '0' && c <= '9') s = s*10+c-'0', c = getchar();
    return f == 1 ? s : -s;
}

const int N = 2010;

struct node {
    int to, next, w;
}g[N<<1];
int last[N], gl;
inline void add(int z, int x, int y) {
    g[++gl] = (node) {y, last[x], z};
    last[x] = gl;
    g[++gl] = (node) {x, last[y], z};
    last[y] = gl;
    return ;
}

int siz[N], n, K;
long long f[N][N];

inline void init(int u, int fa) {
    siz[u] = 1;
    for (int i = last[u]; i; i = g[i].next) {
        int v = g[i].to;
        if (v == fa) continue;
        init(v, u);
        siz[u] += siz[v];
    }
    return ;
}
#define Min(x, y) ((x<y)?x:y)
#define Max(x, y) ((x>y)?x:y)
inline void dfs(int u, int fa) {
    memset(f[u], 128, sizeof(f[u]));
    f[u][0] = f[u][1] = 0;
    for (int i = last[u]; i; i = g[i].next) {
        int v = g[i].to;
        if (v == fa) continue;
        dfs(v, u);
        for (int j = Min(K, siz[u]); j >= 0; j--)
            for (int k = 0; k <= Min(j, siz[v]); k++)
                if (f[u][j-k] >= 0) {
                    long long val = 1ll*k*(K-k)*g[i].w + 1ll*(siz[v]-k)*(n-K+k-siz[v])*g[i].w;
                    f[u][j] = Max(f[u][j], f[u][j-k] + f[v][k] + val);
                }				
    }
    return ;
}

int main() {
    n = gi(), K = gi();
    for (int i = 1; i < n; i++)
        add(gi(), gi(), gi());
    init(1, 0);
    dfs(1, 0);
    printf("%lld\n", f[1][K]);
    return 0;
}

posted @ 2018-10-29 22:10  zzy2005  阅读(153)  评论(0编辑  收藏  举报