树形DP入门
一、基本概念
树形DP,即在树上进行DP。一般都用递归的形式进行实现,根据叶子节点的信息对根节点进行DP。
二、经典问题
1、树的重心
重心的定义:若删去树中的一个点,使得树中各联通块中的最大连通块的结点数最小,则称这个点为树的重心。
实现方法:将无根树转化为有根树,用dfs处理出每个点的的子树大小(size)以及它的所有子树的大小的最大值(maxv)。之后枚举每个点,那么删去它后分成的连通块大小的最大值便是max( maxv[i] , n-size[i] )。
例题:poj1655(模板题,没什么好说的),代码如下:
1 #include<cstdio> 2 #include<vector> 3 using namespace std; 4 5 const int MAXN=20010; 6 int t,n,size[MAXN],maxv[MAXN],ansx,ansy; 7 vector <int> ve[MAXN]; 8 9 int read(void) { 10 char c; while (c=getchar(),c<'0' || c>'9'); int x=c-'0'; 11 while (c=getchar(),c>='0' && c<='9') x=x*10+c-'0'; return x; 12 } 13 14 void dfs(int u,int pre) { //dfs处理出每个点的size与maxv 15 size[u]=1; 16 for (int i=0;i<ve[u].size();++i) { 17 int v=ve[u][i]; 18 if (v==pre) continue; 19 dfs(v,u); 20 size[u]+=size[v]; 21 maxv[u]=max(maxv[u],size[v]); 22 } 23 } 24 25 int main() { 26 t=read(); 27 while (t--) { 28 ansy=2e9; 29 n=read(); 30 for (int i=1;i<=n;++i) ve[i].clear(),size[i]=0; 31 for (int i=1;i<n;++i) { 32 int x=read(),y=read(); 33 ve[x].push_back(y); 34 ve[y].push_back(x); 35 } 36 dfs(1,0); 37 for (int i=1;i<=n;++i) { //枚举每个点作为重心 38 int tmp=max(n-size[i],maxv[i]); 39 if (ansy>tmp) { 40 ansy=tmp; ansx=i; 41 } 42 } 43 printf("%d %d\n",ansx,ansy); 44 } 45 return 0; 46 }
2、树中的最长链
最长链的定义:找到两个点,使得这两个点的距离最远,则称这两个点之间的路径为最长链。
实现的方法:从任意一个点(一般为树根)开始dfs/bfs,找出距离这个点最远的点u。之后再进行一次dfs/bfs,找出离u最长的点v,则u到v的路径则为树的直径。
方法的证明:若起始点为树的最长链上的点,则进行两次寻找后找到的便是最长链上的两端。若起始点不为树的最长链上的点,则设起始点到点u之间的距离为l1,起始点到最长链上某一点(x)的距离为l2,点x到最长链另一端(两段中较长的一端)的距离为l3,最长链的长度为L。若点u不是最长链上的点,则可得 l1 > l2 + l3。假设起始点到u的路径与最长链有交,则 l1 - l2 > l3,即链上某一点到点u的距离大于链上某一点到链上距离较长端的距离,而 l3 >= L/2,所以 l1 - l2 > L/2,这与假设矛盾(l1 + l3的距离也是一条链的距离,而l3 > L/2,则 l1 > l/2 + l2 > L,这与L为最长链的长度的假设不符),所以 l1 + l2 > l3不成立。假设起始点到u的路径与最长链无交,即在以起始点为根的子树中,u与最长链处于不同的子树那最长链应为一条过起始点的最长链,与最长链的假设不符。所以,在一次dfs/bfs后找到的必为最长链上的某一点。
例题:poj1985(路的方向没有什么用,之间跑一遍树的最长链就好了),代码如下:
1 #include<cstdio> 2 #include<vector> 3 #define mp make_pair 4 #define fir first 5 #define sec second 6 #define pa pair <int,int> 7 using namespace std; 8 9 const int MAXN=40010; 10 int n,m,dep[MAXN],inx; 11 vector <pa> ve[MAXN]; 12 13 int read(void) { 14 char c; while (c=getchar(),c<'0' || c>'9'); int x=c-'0'; 15 while (c=getchar(),c>='0' && c<='9') x=x*10+c-'0'; return x; 16 } 17 18 void dfs(int u,int pre) { //处理出每个点的距离 19 for (int i=0;i<ve[u].size();++i) { 20 int v=ve[u][i].fir; 21 if (v==pre) continue; 22 dep[v]=dep[u]+ve[u][i].sec; 23 dfs(v,u); 24 } 25 } 26 27 int main() { 28 n=read(); m=read(); 29 for (int i=1;i<=m;++i) { 30 int x=read(),y=read(),v=read(); 31 ve[x].push_back(mp(y,v)); //不要被路的方向迷惑了,要双向建边 32 ve[y].push_back(mp(x,v)); 33 } 34 dfs(1,0); 35 for (int i=1;i<=n;++i) 36 if (dep[i]>dep[inx]) inx=i; 37 for (int i=1;i<=n;++i) dep[i]=0; 38 dfs(inx,0); inx=0; 39 for (int i=1;i<=n;++i) 40 if (dep[i]>dep[inx]) inx=i; 41 printf("%d",dep[inx]); 42 return 0; 43 }
3、树的中心
树的中心的定义:树上的点到其它端点最大距离的值最小的点。
实现的方法:对于每个点,设up[i]表示除i的叶子节点外,其他节点到这个点的距离最大值;d1[i]表示这个点的叶子节点中到这个点的距离最大值;d2[i]表示这个点的叶子节点中到这个点的距离次大值;c1[i]表示叶子节点最大值所在的子树;c2[i]表示叶子节点次大值所在的子树;maxv[i]表示这个点到其他点距离的最大值。进行两次dfs,在第一次中求出每个点的d1,d2,c1,c2。假设节点u的叶子节点为v,若d1[v] + dist[v][u] > d1[u],则d2[u] = d1[u],c2[u] = c1[u],d1[u] = d1[v] + dist[v][u],c1[u] = v。若d1[v] + dist[v][u] > d2[u],则d2[u] = d1[v] + dist[v][u],c2[u] = v(不需要考虑d2[v] + dist[v][u]与d2[u]的关系,因为d1与d2均为u的子树中的最大与次大,所以只需比较d1即可)。之后再进行一次dfs,处理出每个结点的up值,设u是v的父节点,若c1[u] == v,则up[v] = max( d2[u] , up[u] ) + dist[u][v],若c1[u] != v,则up[v] = max( d1[u] , up[u] ) + dist[u][v]。最后每个结点距它最远的结点的距离为max( up[i] , d1[i] )。
例题:hdu2196,代码如下:
1 #include<cstdio> 2 #include<vector> 3 #include<algorithm> 4 #include<cstring> 5 #define pa pair <int,int> 6 #define fir first 7 #define sec second 8 #define mp make_pair 9 using namespace std; 10 11 const int MAXN=10010; 12 int n,d1[MAXN],d2[MAXN],c1[MAXN],c2[MAXN],up[MAXN],maxv[MAXN]; 13 vector <pa> ve[MAXN]; 14 15 int read(void) { 16 char c; while (c=getchar(),c<'0' || c>'9'); int x=c-'0'; 17 while (c=getchar(),c>='0' && c<='9') x=x*10+c-'0'; return x; 18 } 19 20 void dfs1(int u,int pre) { 21 for (int i=0;i<ve[u].size();++i) { 22 int v=ve[u][i].fir; 23 if (v==pre) continue; 24 dfs1(v,u); 25 if (d1[v]+ve[u][i].sec>d1[u]) { //处理出最大与次大 26 d2[u]=d1[u]; c2[u]=c1[u]; 27 d1[u]=d1[v]+ve[u][i].sec; c1[u]=v; 28 } 29 else if (d1[v]+ve[u][i].sec>d2[u]) { 30 d2[u]=d1[v]+ve[u][i].sec; c2[u]=v; 31 } 32 //无需考虑d2[v]与d2[u]的大小关系 33 } 34 } 35 36 void dfs2(int u,int pre) { 37 for (int i=0;i<ve[u].size();++i) { 38 int v=ve[u][i].fir; 39 if (v==pre) continue; //处理出up 40 if (v==c1[u]) up[v]=max(up[u],d2[u])+ve[u][i].sec; 41 else if (v!=c1[u]) up[v]=max(up[u],d1[u])+ve[u][i].sec; 42 dfs2(v,u); 43 } 44 } 45 46 void init(void) { 47 memset(d1,0,sizeof(d1)); 48 memset(d2,0,sizeof(d2)); 49 memset(up,0,sizeof(up)); 50 for (int i=1;i<=n;++i) ve[i].clear(); 51 } 52 53 int main() { 54 while (scanf("%d",&n)!=EOF) { 55 init(); 56 for (int i=2;i<=n;++i) { 57 int x=read(),y=read(); 58 ve[i].push_back(mp(x,y)); 59 ve[x].push_back(mp(i,y)); 60 } 61 dfs1(1,0); 62 dfs2(1,0); 63 for (int i=1;i<=n;++i) maxv[i]=max(up[i],d1[i]); 64 for (int i=1;i<=n;++i) printf("%d\n",maxv[i]); 65 } 66 return 0; 67 }
4、树上背包
定义:树上背包,顾名思义,即在树上进行背包(好像没什么好说的)。
例题:CTSC1997选课(https://www.luogu.org/problemnew/show/P2014)
解析:加入一个虚拟的结点0连向每一棵树的根,则可将多棵树转化为一课有根树。设dp[u][j]表示以i为根的子树中选j个所获得的最大学分,则可在树上进行分组背包。枚举u的所有子树v以及子树中选了k个,则dp[u][j] = max( dp[u][j] , dp[u][j-k] + dp[v][k] )。注意要在dp的最后加上节点u的学分。(注意!n<=300)。
代码如下:
1 #include<cstdio> 2 #include<vector> 3 #include<algorithm> 4 using namespace std; 5 6 const int MAXN=301; 7 int n,m,val[MAXN],dp[MAXN][MAXN]; 8 vector <int> ve[MAXN]; 9 10 int read(void) { 11 char c; while (c=getchar(),c<'0' || c>'9'); int x=c-'0'; 12 while (c=getchar(),c>='0' && c<='9') x=x*10+c-'0'; return x; 13 } 14 15 void dfs(int u) { 16 dp[u][0]=0; //选0个课程学分为0 17 for (int i=0;i<ve[u].size();++i) { 18 int v=ve[u][i]; 19 dfs(v); 20 for (int j=m;j>=0;--j) //枚举u选了几个 21 for (int k=j;k>=0;--k) //枚举v选了几个 22 dp[u][j]=max(dp[u][j],dp[u][j-k]+dp[v][k]); 23 } 24 if (u) { //加上u节点本身的学分 25 for (int j=m;j>0;--j) dp[u][j]=dp[u][j-1]+val[u]; 26 } 27 } 28 29 int main() { 30 n=read(); m=read(); 31 for (int i=1;i<=n;++i) { 32 int x=read(); val[i]=read(); 33 ve[x].push_back(i); 34 } 35 dfs(0); 36 printf("%d",dp[0][m]); 37 return 0; 38 }