树形dp——覆盖所有边的最少费用(Protecting Zonk)

一、问题描述

有一个n(n<=10000)个节点的无根树。有两种装置A,B,每种都有无限多个。
1.在某个节点X使用A装置需要C1(C1<=1000)的花费,并且此时与节点X相连的边都被覆盖
2.在某个节点X使用B装置需要C2(C2<=1000)的花费,并且此时与节点X相连的边以及与节点X相连的点相连的边都被覆盖
求覆盖所有边的最小花费

二、问题分析

dp[u][0]:u没有安装装置,且u的子节点下的边都被覆盖
dp[u][1]:u安装装置A
dp[u][2]:u安装装置B
dp[u][3]:u没有安装装置,且v可以不安装装置

dp[u][0]=Sum( min(dp[v][1],dp[v][2]) );
dp[u][1]=min( C1+Sum( min(dp[v][0],dp[v][1],dp[v][2]) ),Sum( min(dp[v][2],dp[v][1],dp[v][0]) 且至少有一个子节点选择B) )
dp[u][2]=C2+Sum( min(dp[v][0],dp[v][1],dp[v][2],dp[v][3]) )
dp[u][3]=Sum(min(dp[v][0],dp[v][1],dp[v][2]))

三、代码实现

 1 #include<stdio.h>
 2 #include<iostream>
 3 #include<cstring>
 4 #include<algorithm>
 5 #include<vector>
 6 using namespace std;
 7 
 8 const int INF = 0x3f3f3f3f;
 9 const int maxn = 100000 + 10;
10 struct Edge
11 {
12     int to, next;
13 }e[maxn * 2];
14 
15 //dp[u][0]:u没有安装装置, 且u的子节点下的边都被覆盖
16 //dp[u][1] : u安装装置A
17 //dp[u][2] : u安装装置B
18 //dp[u][3] : u没有安装装置, 且v可以不安装装置
19 int head[maxn], d[maxn][4]; 
20 int tot,n,C1,C2;                                                  
21 
22 int mostmin(int a, int b, int c)
23 {
24     return min(a, min(b, c));
25 }
26 
27 void init()
28 {
29     tot = 0;
30     memset(head, -1, sizeof(head));
31 }
32 
33 void addadge(int from, int to)
34 {
35     e[tot].to = to;
36     e[tot].next = head[from];
37     head[from] = tot++;
38 }
39 
40 void dfs(int u, int fa)
41 {
42     d[u][0] = 0; d[u][1] = C1; d[u][2] = C2; d[u][3] = 0;
43     int flag = 0, sum = 0, mi = INF;
44     for (int i = head[u]; i != -1; i = e[i].next)
45     {
46         int v = e[i].to;
47         if (v == fa)    continue;
48         dfs(v, u);
49         d[u][0] += min(d[v][1], d[v][2]);
50         d[u][1] += mostmin(d[v][0], d[v][1], d[v][2]);
51         d[u][2] += min(mostmin(d[v][0], d[v][1], d[v][2]), d[v][3]);
52         d[u][3] += mostmin(d[v][0], d[v][1], d[v][2]);
53         int tmp = mostmin(d[v][0], d[v][1], d[v][2]);
54         sum += tmp;
55         mi = min(mi, d[v][2] - tmp);
56     }
57     sum += mi;
58     d[u][1] = min(d[u][1], sum);
59 }
60 
61 int main()
62 {
63     while (scanf("%d%d%d",&n,&C1,&C2) == 3 && n)
64     {
65         init();
66         int u, v;
67 
68         for (int i = 0; i < n - 1; i++)
69         {
70             scanf("%d%d", &u, &v);
71             addadge(u, v);
72             addadge(v, u);
73         }
74         dfs(1, 0);
75         printf("%d\n", mostmin(d[1][0], d[1][1], d[1][2]));
76     }
77     return 0;
78 }

 

posted @ 2018-08-16 22:50  Rogn  阅读(441)  评论(0编辑  收藏  举报