洛谷 3177 [HAOI2015] 树上染色

题目描述

有一棵点数为 N 的树,树边有边权。给你一个在 0~ N 之内的正整数 K ,你要在这棵树中选择 K个点,将其染成黑色,并将其他 的N-K个点染成白色 。 将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间的距离的和的受益。问受益最大值是多少。

输入输出格式

输入格式:

 

第一行包含两个整数 N, K 。接下来 N-1 行每行三个正整数 fr, to, dis , 表示该树中存在一条长度为 dis 的边 (fr, to) 。输入保证所有点之间是联通的。

 

输出格式:

 

输出一个正整数,表示收益的最大值。

 

输入输出样例

输入样例#1:
3 1
1 2 1
1 3 2
输出样例#1:
3

说明

对于 100% 的数据, 0<=K<=N <=2000

 

题解

最终的收益等于每一条边的收益(权值乘以被用到的次数)的和,假设dp[i][j]表示以i为根的子树内有j个点为黑点时边的收益,从叶子到根的顺序计算每条边的收益,遍历到i点时,i的子树里的边已经被计算过了,i到其儿子的边被新加了进来,用这些新加的边的收益和子树的收益更新以i为根的子树的收益。

转移方程为dp[i][j]=dp[k][l]+边权*子树内黑点*子树外黑点+边权*子树内白点*子树外白点。

一开始我以为自己计算的是答案的两倍,这样智障了很久,后来把/2删掉后过了,我又仔细思考了一下,我是一条边一条边的算的,所以只会算一次。

#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstdio>
#define nn 2010
#define mm 4010
#define lo long long
#define inf -100000000
using namespace std;
int e=0;
int fir[nn],nxt[mm],to[mm],w[mm],size[nn],n,k;
lo dp[nn][nn];
bool vis[nn]; 
int get()
{
    int ans=0,f=1;char ch=getchar();
    while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
    while(isdigit(ch)) {ans=ans*10+ch-'0';ch=getchar();}
    return ans*f;
}
void add(int a,int b,int c)
{
    nxt[++e]=fir[a];fir[a]=e;to[e]=b;w[e]=c;
    nxt[++e]=fir[b];fir[b]=e;to[e]=a;w[e]=c;
}
void dfs(int o)
{
    size[o]=1;
    for(int i=fir[o];i;i=nxt[i])
      if(!vis[to[i]])
      {
          vis[to[i]]=1;
          dfs(to[i]);
          size[o]+=size[to[i]];
      }
}
void solve(int o)
{
    dp[o][0]=0;
    dp[o][1]=0;
    for(int i=fir[o];i;i=nxt[i])
      if(size[o]>size[to[i]])
      {
          solve(to[i]);
          for(int j=size[o];j>=0;j--)                //把size[o]写成了size[i] 
            for(int p=0;p<=size[to[i]]&&p<=j;p++)
              dp[o][j]=max(dp[o][j],dp[to[i]][p]+dp[o][j-p]+(lo)p*(k-p)*w[i]+(lo)w[i]*(size[to[i]]-p)*(n-k-size[to[i]]+p));
      }
}
int main()
{
    n=get(),k=get();
    int a,b,c;
    for(int i=1;i<n;i++)
    {
        a=get();b=get();c=get();
        add(a,b,c);
    }
    vis[1]=1;
    dfs(1);
    for(int i=1;i<=n;i++)
      for(int j=1;j<=n;j++)
        dp[i][j]=inf;
    solve(1);
    printf("%lld",dp[1][k]);
    return 0;
}

  

 

posted @ 2017-09-10 19:08  o00v00o  阅读(223)  评论(0编辑  收藏  举报