Codeforces 161D Distance in Tree (tree DP)

Problem link:

http://codeforces.com/contest/161/problem/D

D. Distance in Tree

time limit per test 3 seconds
memory limit per test 512 megabytes
#### Problem Statement

> A tree is a connected graph that doesn't contain any cycles.
>
> The distance between two vertices of a tree is the length (in edges) of the shortest path between these vertices.
>
> You are given a tree with n vertices and a positive number k. Find the number of distinct pairs of the vertices which have a distance of exactly k between them. Note that pairs (v, u) and (u, v) are considered to be the same pair.

#### Input

> The first line contains two integers n and k (1 ≤ n ≤ 50000, 1 ≤ k ≤ 500) — the number of vertices and the required distance between the vertices.
>
> Next n - 1 lines describe the edges as "ai bi" (without the quotes) (1 ≤ ai, bi ≤ n, ai ≠ bi), where ai and bi are the vertices connected by the i-th edge. All given edges are different.

#### Output

> Print a single integer — the number of distinct pairs of the tree's vertices which have a distance of exactly k between them.
>
> Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specifier.

#### Sample

> **sample input**
> 5 2
> 1 2
> 2 3
> 3 4
> 2 5
>
> **sample output**
> 4

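Problem summary

Given a tree in which every edge has length 1, count the pairs of vertices whose distance is exactly k; (u,v) and (v,u) count as the same pair.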
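For small inputs it helps to have an obviously correct reference to compare the DP against. The following is a minimal sketch, assuming vertices are 1-indexed and edges are passed as (a, b) pairs; `bruteForcePairs` is a hypothetical helper name, not part of the original solution. It runs a BFS from every vertex, so it costs O(n²) and is far too slow for n = 50000.

```cpp
#include <queue>
#include <utility>
#include <vector>
using namespace std;

// Hypothetical brute-force checker: BFS from every vertex and count the
// vertices at distance exactly k. O(n^2), only for small test trees.
long long bruteForcePairs(int n, int k, const vector<pair<int,int>>& edges) {
    vector<vector<int>> g(n + 1);              // 1-indexed adjacency lists
    for (const auto& e : edges) {
        g[e.first].push_back(e.second);
        g[e.second].push_back(e.first);
    }
    long long ordered = 0;                     // counts (s, t) and (t, s) separately
    for (int s = 1; s <= n; ++s) {
        vector<int> dist(n + 1, -1);
        queue<int> q;
        dist[s] = 0;
        q.push(s);
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            if (dist[u] == k) { ++ordered; continue; }  // never expand past depth k
            for (int v : g[u]) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                }
            }
        }
    }
    return ordered / 2;                        // each unordered pair was seen from both ends
}
```

On the sample, `bruteForcePairs(5, 2, {{1,2},{2,3},{3,4},{2,5}})` should return 4, namely the pairs {1,3}, {1,5}, {3,5} and {2,4}.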

Solution

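Tree DP:
dp[i][j] is the number of vertices at distance j from vertex i.
Two DFS passes:
The first pass computes, inside the subtree rooted at i, the number of vertices at distance j from i.
The second pass, done top-down, adds the number of vertices outside the subtree of i that are at distance j from i, derived from the parent's already complete row.
After both passes, dp[i][j] counts all vertices of the tree at distance j from i.
Every unordered pair is counted once from each of its endpoints, so the answer is (sum over all i of dp[i][k]) / 2.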
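The two passes can be written as recurrences. This is a sketch in our own notation: `down` and `all` are names introduced here, whereas the code keeps both in the single array dp, which holds the `down` row of a vertex until dfs2 reaches it and the `all` row afterwards. The root is vertex 1 and p denotes the parent of u.

```latex
\begin{aligned}
\mathrm{down}[u][0] &= 1, &
\mathrm{down}[u][j] &= \sum_{v \in \mathrm{children}(u)} \mathrm{down}[v][j-1] \quad (j \ge 1),\\
\mathrm{all}[1][j] &= \mathrm{down}[1][j], &
\mathrm{all}[u][0] &= 1,\\
\mathrm{all}[u][j] &= \mathrm{down}[u][j] + \mathrm{all}[p][j-1] - \mathrm{down}[u][j-2] && (j \ge 1,\ \mathrm{down}[u][-1] := 0),\\
\text{answer} &= \tfrac{1}{2} \sum_{i=1}^{n} \mathrm{all}[i][k].
\end{aligned}
```

The subtracted term removes the vertices that are j-1 edges from p but lie inside u's own subtree; seen from u they are only j-2 edges away, so they were already counted in down[u].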

Code

#include<cstdio>
#include<cstring>
#include<iostream>
#include<vector>
using namespace std;

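const int maxn=5e4+10;   // larger than the maximum n = 50000
const int maxm=555;      // larger than the maximum k = 500

typedef long long LL;    // portable replacement for the MSVC-only __int64

int n,k;
// dp[u][j]: number of vertices at distance j from u.
// After dfs1 it counts only u's subtree; after dfs2 it counts the whole tree.
// About 50010*555*8 bytes (roughly 222 MB), within the 512 MB limit.
LL dp[maxn][maxm];
vector<int> G[maxn];     // adjacency lists, vertices are 1-indexed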

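// First pass: for every distance j, count the vertices of u's subtree at distance j from u.
void dfs1(int u,int fa) {
	dp[u][0]=1;                         // u itself is at distance 0
	for(int i=0;i<(int)G[u].size();i++){
		int v=G[u][i];
		if(v==fa) continue;
		dfs1(v,u);
		// a vertex at distance j from child v is at distance j+1 from u
		for(int j=0;j+1<=k;j++){
			dp[u][j+1]+=dp[v][j];
		}
	}
}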

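// tmp is used only before the recursive calls, so one global buffer is enough.
LL tmp[maxm];
// Second pass (preorder): when u is visited, dp[fa] already counts the whole tree,
// while dp[u] still counts only u's subtree.
void dfs2(int u,int fa) {
	if(fa!=-1){
		// tmp[j]: vertices at distance j from fa that are NOT in u's subtree.
		// Those inside u's subtree are at distance j-1 from u, hence the subtraction.
		tmp[0]=dp[fa][0];
		for(int j=1;j<=k;j++){
			tmp[j]=dp[fa][j]-dp[u][j-1];
		}
		// each of them is one edge farther from u than from fa
		for(int j=0;j+1<=k;j++){
			dp[u][j+1]+=tmp[j];
		}
	}
	for(int i=0;i<(int)G[u].size();i++){
		int v=G[u][i];
		if(v==fa) continue;
		dfs2(v,u);
	}
}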

int main() {
	scanf("%d%d",&n,&k);
	memset(dp,0,sizeof(dp));           // globals start at zero anyway; kept for clarity
	for(int i=0; i<n-1; i++) {
		int u,v;
		scanf("%d%d",&u,&v);
		G[u].push_back(v);
		G[v].push_back(u);
	}
	dfs1(1,-1);
	dfs2(1,-1);
	LL ans=0;
	for(int i=1;i<=n;i++){
		ans+=dp[i][k];
	}
	// every unordered pair {u,v} was counted once from u and once from v
	cout<<ans/2<<endl;
	return 0;
}
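dfs1 and dfs2 recurse as deep as the tree is tall, up to about 50000 levels on a path, so the recursive version relies on the judge's stack size. Below is a minimal sketch of the same two passes without recursion, assuming the same 1-indexed input. `countPairsIterative` and the flat `dp` layout are illustrative choices, not the author's code. A BFS order from vertex 1 lists every parent before its children, so walking it backwards plays the role of dfs1 and walking it forwards plays the role of dfs2.

```cpp
#include <cstddef>
#include <utility>
#include <vector>
using namespace std;

// Hypothetical iterative variant of dfs1/dfs2 (same rerooting idea).
long long countPairsIterative(int n, int k, const vector<pair<int,int>>& edges) {
    vector<vector<int>> g(n + 1);               // 1-indexed adjacency lists
    for (const auto& e : edges) {
        g[e.first].push_back(e.second);
        g[e.second].push_back(e.first);
    }
    // BFS from root 1: every parent appears before its children in `order`.
    vector<int> order, parent(n + 1, 0);
    order.reserve(n);
    order.push_back(1);
    parent[1] = -1;
    for (size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (int v : g[u])
            if (v != parent[u]) { parent[v] = u; order.push_back(v); }
    }
    const int W = k + 1;                        // one row of distances 0..k per vertex
    vector<long long> dp((size_t)(n + 1) * W, 0);
    auto at = [&](int u, int j) -> long long& { return dp[(size_t)u * W + j]; };

    // Pass 1 (children before parents), like dfs1: subtree counts.
    for (int i = n - 1; i >= 0; --i) {
        int u = order[i];
        at(u, 0) = 1;                           // u itself
        if (parent[u] != -1)
            for (int j = 0; j + 1 <= k; ++j) at(parent[u], j + 1) += at(u, j);
    }
    // Pass 2 (parents before children), like dfs2: add vertices outside u's subtree.
    vector<long long> tmp(W);
    for (int i = 1; i < n; ++i) {
        int u = order[i], fa = parent[u];
        tmp[0] = at(fa, 0);
        for (int j = 1; j <= k; ++j) tmp[j] = at(fa, j) - at(u, j - 1);
        for (int j = 0; j + 1 <= k; ++j) at(u, j + 1) += tmp[j];
    }
    long long total = 0;
    for (int u = 1; u <= n; ++u) total += at(u, k);
    return total / 2;                           // each unordered pair counted from both ends
}
```

Called on the sample as `countPairsIterative(5, 2, {{1,2},{2,3},{3,4},{2,5}})`, it should give 4, matching the expected output. The flat vector takes roughly (n+1)(k+1)·8 bytes, about 200 MB at the limits, so the memory profile stays close to the original array.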
posted @ 2016-07-29 22:27  fenicnn