POJ 2486 Apple Tree [树状DP]
题目:一棵树,每个结点上都有一些苹果,且相邻两个结点间的距离为1。一个人从根节点(编号为1)开始走,一共可以走k步,问最多可以吃多少苹果。
思路:这里给出数组的定义:
dp[0][x][j] 为从结点x开始走,一共走j步,且j步之后又回到x点时最多能吃到的苹果数。
dp[1][x][j] 为从结点x开始走,一共走j步最多能吃到的苹果数(不必再回到x点)。之所以要定义上面的一种状态是因为在求第二种状态时需要用到。
下面介绍递推公式。
对于结点x,假设它目前要访问的孩子为y,则1...(y-1)已经遍历过。此时有:
dp[0][x][j+2] = max(dp[0][x][j], dp[0][x][m] + dp[0][y][j-m])
注:dp[0][x][m]在每次dp后都会进行更新,此时的dp[0][x][m]实际上只是遍历过孩子结点1...(y-1)的情况。等号左边j之所以要加2是因为右面的总距离没有考虑从点x到y以及从y再回到x的距离,在这里要加上。
dp[1][x][j+1] = max(dp[1][x][j+1], dp[0][x][m] + dp[1][y][j-m])
注:遍历y结点,且不再回来。j加1表示只需要走一次从x到y的边。
dp[1][x][j+2] = max(dp[1][x][j+2], dp[1][x][m] + dp[0][y][j-m])
注:遍历y结点且又回到结点x,则必然是从1...(y-1)中的某个结点有去无回。
另外在dp的过程中,对于结点x的每个孩子都需要枚举走过的步数。而这个步数的枚举从小到大还是从大到小结果是不一样的。这就要看到前面递推公式里的定义,要求dp[0/1][x][m]存储的是遍历y前面的孩子结点的dp值。因此步数要从大到小枚举,这样每次值更新后都不会影响到下次dp。不然会wa。当然,也可以在每次dp前将dp[0/1][x][m]这两个值存储起来,这样就不用考虑这个问题了。
1 #include<stdio.h> 2 #include<string.h> 3 #include<algorithm> 4 #define maxn 105 5 #define maxk 205 6 using namespace std; 7 int n, k; 8 int w[maxn]; 9 bool vis[maxn]; 10 struct node 11 { 12 int v, next; 13 }edge[maxn<<1]; 14 int num_edge, head[maxn]; 15 void init_edge() 16 { 17 num_edge = 0; 18 memset(head, -1, sizeof(head)); 19 } 20 void addedge(int a,int b) 21 { 22 edge[num_edge].v = b; 23 edge[num_edge].next = head[a]; 24 head[a] = num_edge++; 25 } 26 int dp[2][maxn][maxk]; 27 void getdp(int x) 28 { 29 vis[x] = 1; 30 for (int i = 0; i <= k; i++) 31 dp[0][x][i] = dp[1][x][i] = w[x]; 32 for (int i = head[x]; i != -1; i = edge[i].next) 33 { 34 int v = edge[i].v; 35 if (vis[v]) continue; 36 getdp(v); 37 for (int j = k; j >= 0; j--) 38 for (int m = 0; m <= j; m++) 39 { 40 dp[0][x][j+2] = max(dp[0][x][j+2], dp[0][x][m] + dp[0][v][j-m]); 41 dp[1][x][j+1] = max(dp[1][x][j+1], dp[0][x][m] + dp[1][v][j-m]); 42 dp[1][x][j+2] = max(dp[1][x][j+2], dp[1][x][m] + dp[0][v][j-m]); 43 } 44 } 45 } 46 int main() 47 { 48 while (~scanf("%d%d",&n,&k)) 49 { 50 init_edge(); 51 memset(vis, 0, sizeof(vis)); 52 for (int i = 1; i <= n; i++) 53 scanf("%d",&w[i]); 54 for (int i = 1; i < n; i++) 55 { 56 int a, b; 57 scanf("%d%d",&a,&b); 58 addedge(a, b); 59 addedge(b, a); 60 } 61 getdp(1); 62 printf("%d\n", dp[1][1][k]); 63 } 64 return 0; 65 }