树形dp换根,求切断任意边形成的两个子树的直径——hdu6686
换根dp就是先任取一点为根,预处理出一些信息,然后在第二次dfs过程中进行状态的转移处理
本题难点在于任意割断一条边,求出剩下两棵子树的直径:
设割断的边为(u,v),设down[v]为以v为根的子树的直径长度,up[v]为u所在的子树的直径长度,那么down[v]就是很常规的子树直径的换根dp的求法,up[v]则要通过分情况讨论
第一种情况,组成up[v]的两条链,一条是u上方的链,一条是u下方且不属于子树v的链
第二种情况,组成up[v]的两条链都是u下方且不属于子树v的链
那么换根的过程中就要考虑这两种情况,我们必须维护u下的前三条最长的链,和u上一条最长的链,
因为如果v刚好在u深度最大的子树里,考虑第二种情况的up[v]时就要用到u下次长,次次长的链,而考虑第一种情况时就必须要用到u上的最长链,所以最少需要维护u下三条链+u上一条链
切断任意一条边,求剩下两个子树的直径
换根dp,第一次dfs预先处理出的数据dp[u][0|1|2]表示u的最长|次长|次次长 链
down[u]表示u子树里的最长链长度
第二次dfs求出
len[u][0|1]表示u子树里的不经过u的最长|次长 链
枚举每个儿子v,
求出dp[v][3]表示v上面的最长链
up[v]表示切断(u,v)后,u所在块的最长链
那么切断(u,v)后,两个子树的直径就是down[v],up[v]
down[v]好求,up[v]要通过u来求出
换根的过程:从u换到v时,up[v]有两种情况,一种是一条u的上面+一条u的下面,另一种是两条u的下面
#include<bits/stdc++.h> #include<vector> using namespace std; #define N 200005 vector<int>G[N]; int n,u[N],v[N],a[N],down[N],up[N],dp[N][4],len[N][2],d[N]; void dfs1(int u,int pre,int dep){ d[u]=dep; for(int i=0;i<G[u].size();i++){ int v=G[u][i]; if(v==pre)continue; dfs1(v,u,dep+1); int tmp=dp[v][0]+1; if(tmp>dp[u][0])swap(dp[u][0],tmp); if(tmp>dp[u][1])swap(dp[u][1],tmp); if(tmp>dp[u][2])swap(dp[u][2],tmp); down[u]=max(down[u],down[v]); } down[u]=max(down[u],dp[u][0]+dp[u][1]); } void dfs2(int u,int pre){ for(int i=0;i<G[u].size();i++){//求出u下不经过u的最长|次长 链 int v=G[u][i]; if(v==pre)continue; int tmp=down[v]; if(tmp>len[u][0])swap(tmp,len[u][0]); if(tmp>len[u][1])swap(tmp,len[u][1]); } for(int i=0;i<G[u].size();i++){//边(u,v)将原树分成两棵子树 int v=G[u][i]; if(v==pre)continue; //原树删掉v子树后,求v上的最长链dp[v][3],同时求出第一种情况的up[v] if(dp[u][0]==dp[v][0]+1){//v是u的最深子树 dp[v][3]=max(dp[u][3],dp[u][1])+1; up[v]=max(dp[u][2],dp[u][3])+dp[u][1]; } else if(dp[u][1]==dp[v][0]+1){//v是u的次深子树 dp[v][3]=max(dp[u][0],dp[u][3])+1; up[v]=max(dp[u][2],dp[u][3])+dp[u][0]; } else {//v是u的其他子树 dp[v][3]=max(dp[u][0],dp[u][3])+1; up[v]=max(dp[u][1],dp[u][3])+dp[u][0]; } //求第二种情况的up[v],也要特别判一下v子树里是否有u下的最长链 if(len[u][0]==down[v])up[v]=max(up[v],len[u][1]); else up[v]=max(up[v],len[u][0]); dfs2(v,u); } } void init(){ memset(dp,0,sizeof dp); memset(a,0,sizeof a); memset(len,0,sizeof len); memset(down,0,sizeof down); memset(up,0,sizeof up); for(int i=1;i<=n;i++)G[i].clear(); } int main(){ int t;cin>>t;while(t--){ init();cin>>n; for(int i=1;i<n;i++){ scanf("%d%d",&u[i],&v[i]); G[u[i]].push_back(v[i]); G[v[i]].push_back(u[i]); } dfs1(1,1,0); dfs2(1,1); for(int i=1;i<n;i++){ int x=u[i],y=v[i]; if(d[x]<d[y])swap(x,y); a[up[x]+1]=max(a[up[x]+1],down[x]+1); a[down[x]+1]=max(a[down[x]+1],up[x]+1); } long long ans=0; for(int i=n;i>=1;i--){ a[i]=max(a[i],a[i+1]); ans+=a[i]; } cout<<ans<<'\n'; } }