hiho1055/hdu1561 - 树形dp转换成背包

题目链接

输入:一棵树,每个节点一个权值。

输出:包括1号节点在内的m个节点组成的连通分量的权值和的最大值

hdu1561和hiho1055一样,只是变换了下说法

/**********************************************/

计 dp(i,j) 为以i为根的子树选中j个点(包括i)时的最大权值和。则dp(1,m)即为所求。

方程:

{

      dp[i][0] = 0;

      dp[i][1] = value[i];

      foreach child c of i

           for j = m...2

               for k = 1...j-1

                    dp[i][j] = max(dp[i][j],dp[i][j-k]+dp[c][k])

}

因为dp的核心就是记忆化搜索,所以自下向上处理整棵树,处理完一个节点就标记一下,下次用到这个节点的时候就不用再递归了。

这里我用getCnt()函数计算了一下以每个节点i为根的子树中节点的数目cnt(i),为的是缩小求dp(i,j)中j和k的上限,由m变为MIN(m,cnt(i)),应该不会提速多少

 

#include <set>
#include <map>
#include <stack>
#include <queue>
#include <cmath>
#include <vector>
#include <string>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>

#define MAX(a,b) ((a)>=(b)?(a):(b))
#define MIN(a,b) ((a)<=(b)?(a):(b))
#define OO 0x0fffffff
using namespace std;
const int N = 111;

struct Edge{
    int to;
    int next;
};
int eid = 0;
Edge edges[N*2];
int heads[N];
void addEdge(int a,int b){
     edges[eid].to = a;
     edges[eid].next = heads[b];
     heads[b] = eid++;

     edges[eid].to = b;
     edges[eid].next = heads[a];
     heads[a] = eid++;
}

int m,n;
int dp[N][N],cnt[N],visited[N];

void getCnt(int id){
    visited[id] = 1;
    for(int cur = heads[id];cur!=-1;cur=edges[cur].next){
        int cid = edges[cur].to;
        if(!visited[cid]) {
           if(!cnt[cid]) getCnt(cid);
           cnt[id] += cnt[cid];
        }
    }
    cnt[id] += 1;
}
void traverse(int id){
    visited[id] = 1;
    for(int cur=heads[id];cur!=-1;cur=edges[cur].next){
        int cid = edges[cur].to;
        if(!visited[cid]){
            if(!visited[cid]) traverse(cid);
            for(int i=MIN(m,cnt[id]);i>=2;i--){
                for(int j=1;j<MIN(i,cnt[cid]+1);j++){
                   dp[id][i]=MAX(dp[id][i],(dp[id][i-j]+dp[cid][j]));
                }
            }
        }
    }
}
int main(){
    scanf("%d%d",&n,&m);
    memset(dp,0,sizeof(dp));
    for(int i=1;i<=n;i++) scanf("%d",dp[i]+1);

    int a,b;
    memset(heads,-1,sizeof(heads));
    for(int i=0;i<n-1;i++){
        scanf("%d%d",&a,&b);
        addEdge(a,b);
    }

    memset(cnt,0,sizeof(cnt));
    memset(visited,0,sizeof(visited));
    getCnt(1);

    memset(visited,0,sizeof(visited));
    traverse(1);

    printf("%d\n",dp[1][m]);
    return 0;
}

 

posted @ 2017-05-13 17:48  redips  阅读(173)  评论(0编辑  收藏  举报