树形DP入门

一、基本概念

       树形DP,即在树上进行DP。一般都用递归的形式进行实现,根据叶子节点的信息对根节点进行DP。

二、经典问题

       1、树的重心

             重心的定义:若删去树中的一个点,使得树中各联通块中的最大连通块的结点数最小,则称这个点为树的重心。

             实现方法:将无根树转化为有根树,用dfs处理出每个点的的子树大小(size)以及它的所有子树的大小的最大值(maxv)。之后枚举每个点,那么删去它后分成的连通块大小的最大值便是max( maxv[i] , n-size[i] )。

例题:poj1655(模板题,没什么好说的),代码如下:

 1 #include<cstdio>
 2 #include<vector>
 3 using namespace std;
 4 
 5 const int MAXN=20010;
 6 int t,n,size[MAXN],maxv[MAXN],ansx,ansy;
 7 vector <int> ve[MAXN];
 8 
 9 int read(void) {
10     char c; while (c=getchar(),c<'0' || c>'9'); int x=c-'0';
11     while (c=getchar(),c>='0' && c<='9') x=x*10+c-'0'; return x;
12 }
13 
14 void dfs(int u,int pre) { //dfs处理出每个点的size与maxv 
15     size[u]=1;
16     for (int i=0;i<ve[u].size();++i) {
17       int v=ve[u][i];
18         if (v==pre) continue;
19       dfs(v,u);
20       size[u]+=size[v];
21       maxv[u]=max(maxv[u],size[v]);
22     }
23 }
24 
25 int main() {
26     t=read();
27       while (t--) {
28           ansy=2e9;
29           n=read();
30             for (int i=1;i<=n;++i) ve[i].clear(),size[i]=0;
31             for (int i=1;i<n;++i) {
32                   int x=read(),y=read();
33                   ve[x].push_back(y);
34                   ve[y].push_back(x);
35             }
36         dfs(1,0);
37           for (int i=1;i<=n;++i) { //枚举每个点作为重心 
38               int tmp=max(n-size[i],maxv[i]);
39                 if (ansy>tmp) {
40                       ansy=tmp; ansx=i;
41                 }
42           }
43         printf("%d %d\n",ansx,ansy);
44       }
45     return 0;
46 } 

 

       2、树中的最长链

             最长链的定义:找到两个点,使得这两个点的距离最远,则称这两个点之间的路径为最长链。

             实现的方法:从任意一个点(一般为树根)开始dfs/bfs,找出距离这个点最远的点u。之后再进行一次dfs/bfs,找出离u最长的点v,则u到v的路径则为树的直径。

             方法的证明:若起始点为树的最长链上的点,则进行两次寻找后找到的便是最长链上的两端。若起始点不为树的最长链上的点,则设起始点到点u之间的距离为l1,起始点到最长链上某一点(x)的距离为l2,点x到最长链另一端(两段中较长的一端)的距离为l3,最长链的长度为L。若点u不是最长链上的点,则可得 l1 > l2 + l3。假设起始点到u的路径与最长链有交,则 l1 - l2 > l3,即链上某一点到点u的距离大于链上某一点到链上距离较长端的距离,而 l3 >= L/2,所以 l1 - l2 > L/2,这与假设矛盾(l1 + l3的距离也是一条链的距离,而l3 > L/2,则 l1 > l/2 + l2 > L,这与L为最长链的长度的假设不符),所以 l1 + l2 > l3不成立。假设起始点到u的路径与最长链无交,即在以起始点为根的子树中,u与最长链处于不同的子树那最长链应为一条过起始点的最长链,与最长链的假设不符。所以,在一次dfs/bfs后找到的必为最长链上的某一点。

例题:poj1985(路的方向没有什么用,之间跑一遍树的最长链就好了),代码如下:

 

 1 #include<cstdio>
 2 #include<vector>
 3 #define mp make_pair
 4 #define fir first
 5 #define sec second
 6 #define pa pair <int,int>
 7 using namespace std;
 8 
 9 const int MAXN=40010;
10 int n,m,dep[MAXN],inx;
11 vector <pa> ve[MAXN];
12 
13 int read(void) {
14     char c; while (c=getchar(),c<'0' || c>'9'); int x=c-'0';
15     while (c=getchar(),c>='0' && c<='9') x=x*10+c-'0'; return x;
16 }
17 
18 void dfs(int u,int pre) { //处理出每个点的距离 
19     for (int i=0;i<ve[u].size();++i) {
20       int v=ve[u][i].fir;
21         if (v==pre) continue;
22       dep[v]=dep[u]+ve[u][i].sec;
23       dfs(v,u);
24     }
25 }
26 
27 int main() {
28     n=read(); m=read();
29       for (int i=1;i<=m;++i) {
30           int x=read(),y=read(),v=read();
31           ve[x].push_back(mp(y,v)); //不要被路的方向迷惑了,要双向建边 
32           ve[y].push_back(mp(x,v));
33       }
34     dfs(1,0);
35       for (int i=1;i<=n;++i) 
36         if (dep[i]>dep[inx]) inx=i;
37       for (int i=1;i<=n;++i) dep[i]=0;
38     dfs(inx,0); inx=0;
39       for (int i=1;i<=n;++i)
40         if (dep[i]>dep[inx]) inx=i;
41     printf("%d",dep[inx]);
42     return 0;
43 }

 

       3、树的中心

            树的中心的定义:树上的点到其它端点最大距离的值最小的点。

            实现的方法:对于每个点,设up[i]表示除i的叶子节点外,其他节点到这个点的距离最大值;d1[i]表示这个点的叶子节点中到这个点的距离最大值;d2[i]表示这个点的叶子节点中到这个点的距离次大值;c1[i]表示叶子节点最大值所在的子树;c2[i]表示叶子节点次大值所在的子树;maxv[i]表示这个点到其他点距离的最大值。进行两次dfs,在第一次中求出每个点的d1,d2,c1,c2。假设节点u的叶子节点为v,若d1[v] + dist[v][u] > d1[u],则d2[u] = d1[u],c2[u] = c1[u],d1[u] = d1[v] + dist[v][u],c1[u] = v。若d1[v] + dist[v][u] > d2[u],则d2[u] = d1[v] + dist[v][u],c2[u] = v(不需要考虑d2[v] + dist[v][u]与d2[u]的关系,因为d1与d2均为u的子树中的最大与次大,所以只需比较d1即可)。之后再进行一次dfs,处理出每个结点的up值,设u是v的父节点,若c1[u] == v,则up[v] = max( d2[u] , up[u] ) + dist[u][v],若c1[u] != v,则up[v] = max( d1[u] , up[u] ) + dist[u][v]。最后每个结点距它最远的结点的距离为max( up[i] , d1[i] )。

例题:hdu2196,代码如下:

 1 #include<cstdio>
 2 #include<vector>
 3 #include<algorithm>
 4 #include<cstring>
 5 #define pa pair <int,int>
 6 #define fir first
 7 #define sec second
 8 #define mp make_pair
 9 using namespace std;
10 
11 const int MAXN=10010;
12 int n,d1[MAXN],d2[MAXN],c1[MAXN],c2[MAXN],up[MAXN],maxv[MAXN];
13 vector <pa> ve[MAXN];
14 
15 int read(void) {
16     char c; while (c=getchar(),c<'0' || c>'9'); int x=c-'0';
17     while (c=getchar(),c>='0' && c<='9') x=x*10+c-'0'; return x;
18 }
19 
20 void dfs1(int u,int pre) {
21     for (int i=0;i<ve[u].size();++i) {
22       int v=ve[u][i].fir;
23         if (v==pre) continue;
24       dfs1(v,u);
25         if (d1[v]+ve[u][i].sec>d1[u]) { //处理出最大与次大 
26           d2[u]=d1[u]; c2[u]=c1[u];
27           d1[u]=d1[v]+ve[u][i].sec; c1[u]=v;
28         } 
29         else if (d1[v]+ve[u][i].sec>d2[u]) {
30           d2[u]=d1[v]+ve[u][i].sec; c2[u]=v;
31         }
32         //无需考虑d2[v]与d2[u]的大小关系 
33     }
34 }
35 
36 void dfs2(int u,int pre) {
37     for (int i=0;i<ve[u].size();++i) {
38       int v=ve[u][i].fir;
39         if (v==pre) continue; //处理出up 
40         if (v==c1[u]) up[v]=max(up[u],d2[u])+ve[u][i].sec;
41         else if (v!=c1[u]) up[v]=max(up[u],d1[u])+ve[u][i].sec;
42       dfs2(v,u);
43     }
44 }
45 
46 void init(void) {
47     memset(d1,0,sizeof(d1));
48     memset(d2,0,sizeof(d2));
49     memset(up,0,sizeof(up));
50     for (int i=1;i<=n;++i) ve[i].clear();
51 }
52 
53 int main() {
54     while (scanf("%d",&n)!=EOF) {
55       init();
56         for (int i=2;i<=n;++i) {
57             int x=read(),y=read();
58             ve[i].push_back(mp(x,y));
59             ve[x].push_back(mp(i,y));
60         }
61       dfs1(1,0);
62       dfs2(1,0);
63         for (int i=1;i<=n;++i) maxv[i]=max(up[i],d1[i]);
64         for (int i=1;i<=n;++i) printf("%d\n",maxv[i]);
65     }
66     return 0;
67 } 

 

       4、树上背包

             定义:树上背包,顾名思义,即在树上进行背包(好像没什么好说的)。

例题:CTSC1997选课(https://www.luogu.org/problemnew/show/P2014

解析:加入一个虚拟的结点0连向每一棵树的根,则可将多棵树转化为一课有根树。设dp[u][j]表示以i为根的子树中选j个所获得的最大学分,则可在树上进行分组背包。枚举u的所有子树v以及子树中选了k个,则dp[u][j] = max( dp[u][j] , dp[u][j-k] + dp[v][k] )。注意要在dp的最后加上节点u的学分。(注意!n<=300)。

代码如下:

 1 #include<cstdio>
 2 #include<vector>
 3 #include<algorithm>
 4 using namespace std;
 5 
 6 const int MAXN=301;
 7 int n,m,val[MAXN],dp[MAXN][MAXN];
 8 vector <int> ve[MAXN];
 9 
10 int read(void) {
11     char c; while (c=getchar(),c<'0' || c>'9'); int x=c-'0';
12     while (c=getchar(),c>='0' && c<='9') x=x*10+c-'0'; return x;
13 }
14 
15 void dfs(int u) {
16     dp[u][0]=0; //选0个课程学分为0 
17     for (int i=0;i<ve[u].size();++i) {
18       int v=ve[u][i];
19       dfs(v);
20         for (int j=m;j>=0;--j) //枚举u选了几个 
21           for (int k=j;k>=0;--k) //枚举v选了几个 
22             dp[u][j]=max(dp[u][j],dp[u][j-k]+dp[v][k]);
23     }
24     if (u) { //加上u节点本身的学分 
25       for (int j=m;j>0;--j) dp[u][j]=dp[u][j-1]+val[u];
26     }
27 }
28 
29 int main() {
30     n=read(); m=read();
31       for (int i=1;i<=n;++i) {
32           int x=read(); val[i]=read();
33           ve[x].push_back(i);
34       }
35     dfs(0);
36     printf("%d",dp[0][m]);
37     return 0;
38 } 

 

posted @ 2018-10-05 22:48  Gax_c  阅读(383)  评论(0编辑  收藏  举报