poj 4045 (树形DP)

先选一点为根节点找出所有父节点i到下面所有点距离和dp[i],该父节点下面有多少个点Node[i]。

然后求出所有节点的所有非子节点到该点的距离dp1[v]+=(dp1[u]+(dp[u]-dp[v]-Node[v]-1)+n-Node[v]-1)

dp[u]-dp[v]-Node[v]-1:u的子节点中除了v这一部分子节点到u的距离

n-Node[v]-1:非v的字节点的个数







 

#include<stdio.h>
#include<string.h>
#define N 50002
#define inf 0x3fffffff
int head[N],num,vis[N],dp[N],Node[N],dp1[N],n,I,R;
struct edge
{
	int st,ed,next;
}E[N*2];
void addedge(int x,int y)
{
	E[num].st=x;
	E[num].ed=y;
	E[num].next=head[x];
	head[x]=num++;
}
void dfs(int u)
{
	vis[u]=1;
	int i,v;
	for(i=head[u];i!=-1;i=E[i].next)
	{
		v=E[i].ed;
		if(vis[v]==1)continue;
		dfs(v);
		dp[u]+=(dp[v]+Node[v]+1);//所有子节点到到父节点的距离
		Node[u]+=(Node[v]+1);//子节点个数
	}
}
long long  mm;
void dfs1(int u)
{
	int i,v;
	vis[u]=1;
	for(i=head[u];i!=-1;i=E[i].next)
	{
		v=E[i].ed;
		if(vis[v]==1)continue;
		dp1[v]+=(dp1[u]+(dp[u]-dp[v]-Node[v]-1)+n-Node[v]-1);//除了子节点外所有节点到该点的距离
		dfs1(v);
	}
	if(mm>dp[u]+dp1[u])
		mm=dp[u]+dp1[u];
}
int main()
{
	int i,x,y,t;
	scanf("%d",&t);
	while(t--)
	{
		scanf("%d%d%d",&n,&I,&R);
		memset(head,-1,sizeof(head));
		num=0;
		for(i=1;i<n;i++)
		{
			scanf("%d%d",&x,&y);
			addedge(x,y);
			addedge(y,x);
		}
		memset(dp,0,sizeof(dp));
		memset(dp1,0,sizeof(dp1));
		memset(Node,0,sizeof(Node));
		memset(vis,0,sizeof(vis));
		mm=inf;
		dfs(1);
		memset(vis,0,sizeof(vis));
		dfs1(1);
		printf("%lld\n",I*I*R*mm);
		for(i=1;i<=n;i++)
		{
		  if(dp[i]+dp1[i]==mm)
			  printf("%d ",i);
		}
		printf("\n\n");
	}
	return 0;
}


 

 

posted on 2013-07-31 19:30  you Richer  阅读(185)  评论(0编辑  收藏  举报