[HAOI2015]树上染色(树形背包)

有一棵点数为 N 的树,树边有边权。给你一个在 0~ N 之内的正整数 K ,你要在这棵树中选择 K个点,将其染成黑色,并将其他 的N-K个点染成白色 。 将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间的距离的和的受益。问受益最大值是多少。

Solution

比较经典的树形背包问题。

如果只对点进行分析,情况会变得十分麻烦,不放考虑每条变的贡献,每条边会产生两边黑点数的乘积加上两边白点数的乘积。

这样的话我们直接跑背包就可以了,标准的树形背包是n^3的,但是这道题每颗字数背包体积有上限,总复杂度可以做到n^2.

Code

#include<iostream>
#include<cstdio>
#include<cstring>
#define N 2009
using namespace std;
long long dp[N][N];
int size[N],m,n,a,b,c,tot,head[N];
struct dsd
{
    int n,to,l;
}an[N<<1];
inline void add(int u,int v,int l)
{
    an[++tot].n=head[u];
    an[tot].to=v;
    head[u]=tot;
    an[tot].l=l;
}
void dfs(int u,int fa)
{
  size[u]=1;
  for(int i=head[u];i;i=an[i].n)
  if(an[i].to!=fa)
  {
      int v=an[i].to;
      dfs(v,u);
    size[u]+=size[v];
    for(int j=min(m,size[u]);j>=0;--j)//
      for(int k=0;k<=min(j,size[v]);++k)
       if(dp[v][k]!=-0x3f3f3f3f)
      {
          long long num=(long long)(k*(m-k)+(n-size[v]-(m-k))*(size[v]-k))*an[i].l;
          dp[u][j]=max(dp[u][j],dp[u][j-k]+dp[v][k]+num);
      }
  } 
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;++i)
    {
      scanf("%d%d%d",&a,&b,&c);
      add(a,b,c);add(b,a,c);
    }
    memset(dp,-0x3f,sizeof(dp));
    for(int i=1;i<=n;++i)
      dp[i][0]=dp[i][1]=0;//
    dfs(1,0);
    cout<<dp[1][m];
    return 0;
} 

 这种写法太慢了,并没有做到严格n^2,bzoj会TLE,下面这种写法是稳过的。

Code

#include<iostream>
#include<cstdio>
#include<cstring>
#define N 2009
using namespace std;
typedef long long ll;
ll dp[N][N],size[N],m,n,a,b,c,tot,head[N],g[N];
struct dsd
{
    ll n,to,l;
}an[N<<1];
inline void add(ll u,ll v,ll l)
{
    an[++tot].n=head[u];
    an[tot].to=v;
    head[u]=tot;
    an[tot].l=l;
}
ll mi(ll x,ll y){return x<y?x:y;}
ll ma(ll x,ll y){return x<y?y:x;}
void dfs(ll u,ll fa){
  size[u]=1;
  for(ll i=head[u];i;i=an[i].n)
  if(an[i].to!=fa){
      ll v=an[i].to;
      dfs(v,u);
    ll x=mi(m,size[u]),y=mi(m,size[v]);
    for(int j=0;j<=m;++j)g[j]=0;
    for(ll j=x;j>=0;--j)
      for(int k=0;k<=y;++k)if(j+k<=m){
          ll gyx=((ll)k*(m-k)+(n-size[v]-(m-k))*(size[v]-k))*an[i].l;
          g[j+k]=ma(g[j+k],dp[u][j]+dp[v][k]+gyx);
      }
    for(int j=0;j<=m;++j)dp[u][j]=g[j];
    size[u]+=size[v];
  } 
}
inline int rd(){
    int x=0;char c=getchar();
    while(!isdigit(c))c=getchar();
    while(isdigit(c)){
        x=(x<<1)+(x<<3)+(c^48);
        c=getchar();
    }
    return x;
} 
int main()
{
    n=rd();m=rd();
    for(int i=1;i<n;++i){
      a=rd();b=rd();c=rd();
      add(a,b,c);add(b,a,c);
    }
    memset(dp,-0x3f,sizeof(dp));
    for(int i=1;i<=n;++i)
      dp[i][0]=dp[i][1]=0;
    dfs(1,0);
    printf("%lld",dp[1][m]);
    return 0;
} 

 

posted @ 2018-10-02 19:16  comld  阅读(166)  评论(0编辑  收藏  举报