poj1741 Tree
Tree
Time Limit: 1000MS | Memory Limit: 30000K
Total Submissions: 25816 | Accepted: 8586
Description
You are given a tree with n vertices; each edge has a length (a positive integer less than 1001).
Define dist(u, v) as the minimum distance between nodes u and v.
Given an integer k, a pair (u, v) of vertices is called valid if and only if dist(u, v) does not exceed k.
Write a program that counts how many valid pairs the given tree contains.
Input
The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between nodes u and v.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
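Reading the sample as the edges 1-2 (length 3), 1-3 (length 1), 1-4 (length 2) and 3-5 (length 1) with k = 4, all ten pairwise distances can be checked by hand:

```latex
\begin{aligned}
\text{valid } (\le 4):\ & d(1,2)=3,\ d(1,3)=1,\ d(1,4)=2,\ d(1,5)=1+1=2,\\
& d(2,3)=3+1=4,\ d(3,4)=1+2=3,\ d(3,5)=1,\ d(4,5)=2+1+1=4,\\
\text{invalid } (>4):\ & d(2,4)=3+2=5,\ d(2,5)=3+1+1=5.
\end{aligned}
```

Eight of the ten pairs satisfy dist(u, v) <= 4, which is the expected output.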
Problem summary: count the pairs of vertices whose distance on the tree is at most k.
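For cross-checking a fast solution on small inputs, a brute-force reference can be handy. The sketch below is not the author's method: it assumes 1-indexed vertices and a hypothetical `Edge` struct, runs one DFS from every vertex (the path between two tree vertices is unique, so it is also the shortest), and counts each unordered pair once. It takes O(n^2) time, so it is only meant for small random trees.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical edge record: undirected edge {u, v} of length len, 1-indexed.
struct Edge {
    int u, v, len;
};

// Brute-force reference: one DFS per source vertex, O(n^2) overall.
long long countValidPairsNaive(int n, int k, const std::vector<Edge>& edges) {
    assert(static_cast<int>(edges.size()) == n - 1);  // a tree has n-1 edges
    std::vector<std::vector<std::pair<int, int>>> adj(n + 1);  // (neighbor, length)
    for (const Edge& e : edges) {
        adj[e.u].push_back(std::make_pair(e.v, e.len));
        adj[e.v].push_back(std::make_pair(e.u, e.len));
    }
    std::vector<long long> dist(n + 1);
    std::vector<int> parent(n + 1);
    long long pairs = 0;
    for (int s = 1; s <= n; ++s) {
        // In a tree the path from s is unique, hence it is the shortest one.
        std::vector<int> todo = {s};
        dist[s] = 0;
        parent[s] = 0;  // 0 is not a vertex
        while (!todo.empty()) {
            int u = todo.back();
            todo.pop_back();
            if (u > s && dist[u] <= k) ++pairs;  // count each unordered pair once
            for (const std::pair<int, int>& p : adj[u]) {
                int v = p.first;
                if (v == parent[u]) continue;
                parent[v] = u;
                dist[v] = dist[u] + p.second;
                todo.push_back(v);
            }
        }
    }
    return pairs;
}
```

On the sample tree with k = 4 it returns 8, matching the expected output.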
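Analysis: this is a textbook problem for centroid decomposition. The overall steps are: find the centroid, count the pairs whose path passes through the centroid, then recursively apply the same divide and conquer to each subtree of the centroid. Whenever we search for a centroid or compute distances, we must check whether the current vertex has already been visited; otherwise the traversal may walk back into a previous centroid.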
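The "find the centroid" step can be sketched on its own. This is an illustration under assumptions, not the author's `getroot()`: `findCentroid` is a hypothetical name, the tree is a plain 1-indexed adjacency list, and `removed` marks vertices already chosen as centroids, which is exactly the "already visited" check the analysis insists on. It walks the component with an explicit stack, computes subtree sizes bottom-up, then picks the vertex whose largest remaining piece is smallest.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: returns a centroid of the component that contains
// `start`, ignoring vertices flagged in `removed` (earlier centroids).
// adj is a 1-indexed adjacency list; index 0 is unused.
int findCentroid(const std::vector<std::vector<int>>& adj,
                 const std::vector<bool>& removed, int start) {
    assert(!removed[start]);
    std::vector<int> parent(adj.size(), -1), sz(adj.size(), 0), order;
    std::vector<int> todo = {start};
    parent[start] = 0;  // 0 is not a vertex, so it never matches a neighbor
    // Preorder walk of the component; never step into a removed vertex.
    while (!todo.empty()) {
        int u = todo.back();
        todo.pop_back();
        order.push_back(u);
        for (int v : adj[u]) {
            if (v == parent[u] || removed[v]) continue;
            parent[v] = u;
            todo.push_back(v);
        }
    }
    // Children come after their parent in `order`, so a reverse sweep
    // accumulates subtree sizes bottom-up.
    for (int i = static_cast<int>(order.size()) - 1; i >= 0; --i) {
        int u = order[i];
        sz[u] += 1;
        if (u != start) sz[parent[u]] += sz[u];
    }
    const int total = static_cast<int>(order.size());
    int best = start, bestWorst = total;
    for (int u : order) {
        int worst = total - sz[u];  // the piece on the parent's side
        for (int v : adj[u])
            if (v != parent[u] && !removed[v]) worst = std::max(worst, sz[v]);
        if (worst < bestWorst) {
            bestWorst = worst;
            best = u;
        }
    }
    return best;
}
```

On the sample tree it returns vertex 1 (its largest remaining piece is {3, 5}, of size 2). Once vertex 1 is marked removed, a call starting from 3 stays inside the component {3, 5} and never crosses back through vertex 1.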
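```cpp
#include <vector>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

const int maxn = 20010;  // room for 2*(n-1) directed edges, n <= 10000
int head[maxn], to[maxn], nextt[maxn], w[maxn], tot = 1, vis[maxn], d[maxn];
int ans, root, sizee[maxn], f[maxn], sum, k, n;
vector<int> q;

// Add a directed edge x -> y of length z (linked-list adjacency).
void add(int x, int y, int z) {
    w[tot] = z;
    to[tot] = y;
    nextt[tot] = head[x];
    head[x] = tot++;
}

// Compute subtree sizes in the current component of `sum` vertices and keep
// in `root` the vertex whose largest remaining piece f[u] is smallest.
void getroot(int u, int fa) {
    f[u] = 0;
    sizee[u] = 1;
    for (int i = head[u]; i; i = nextt[i]) {
        int v = to[i];
        if (v == fa || vis[v]) continue;  // never cross a processed centroid
        getroot(v, u);
        sizee[u] += sizee[v];
        f[u] = max(f[u], sizee[v]);
    }
    f[u] = max(f[u], sum - sizee[u]);
    if (f[u] < f[root]) root = u;
}

// Push into q the distance from the current centroid of u (which is p)
// and of every unvisited vertex below u.
void getdep(int u, int fa, int p) {
    d[u] = p;
    q.push_back(d[u]);
    for (int i = head[u]; i; i = nextt[i]) {
        int v = to[i];
        if (v == fa || vis[v]) continue;
        getdep(v, u, p + w[i]);
    }
}

// Count pairs among the collected depths whose sum is <= k,
// using two pointers on the sorted list.
int calc(int u, int p) {
    int res = 0;
    q.clear();
    getdep(u, 0, p);
    sort(q.begin(), q.end());
    int l = 0, r = q.size() - 1;
    while (l < r) {
        if (q[l] + q[r] <= k) res += r - l++;
        else r--;
    }
    return res;
}

void dfs(int u) {
    vis[u] = 1;
    ans += calc(u, 0);  // all pairs measured through centroid u
    for (int i = head[u]; i; i = nextt[i]) {
        int v = to[i];
        if (!vis[v]) {
            ans -= calc(v, w[i]);  // drop pairs lying inside the same subtree
            // sizee[v] comes from the last getroot pass and may overestimate
            // this component; the chosen root may then not be the exact
            // centroid, but the count is still correct.
            f[0] = sum = sizee[v];
            getroot(v, root = 0);
            dfs(root);
        }
    }
}

int main() {
    while (scanf("%d%d", &n, &k) == 2 && n + k) {
        memset(head, 0, sizeof(head));
        memset(vis, 0, sizeof(vis));
        tot = 1;
        ans = 0;
        for (int i = 1; i < n; i++) {
            int a, b, c;
            scanf("%d%d%d", &a, &b, &c);
            add(a, b, c);
            add(b, a, c);
        }
        f[0] = sum = n;  // f[0] is a sentinel that any real vertex beats
        root = 0;
        getroot(1, 0);
        dfs(root);
        printf("%d\n", ans);
    }
    return 0;
}
```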
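The counting in `calc()` and the subtraction in `dfs()` can be isolated into a small sketch. `countPairsWithin` and `pairsThroughCentroid` are hypothetical helper names, not part of the solution above; `pairsThroughCentroid` assumes the caller has already collected, for each child subtree of the centroid, the distances of its vertices from the centroid. Counting over all depths also counts pairs that sit in the same subtree (their real path does not pass through the centroid), so those are subtracted, just as `dfs()` does with `ans -= calc(v, w[i])`.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Two-pointer count as in calc(): number of pairs i < j with
// depths[i] + depths[j] <= k. Taken by value because it sorts its copy.
long long countPairsWithin(std::vector<int> depths, int k) {
    std::sort(depths.begin(), depths.end());
    long long res = 0;
    int l = 0, r = static_cast<int>(depths.size()) - 1;
    while (l < r) {
        if (depths[l] + depths[r] <= k) {
            res += r - l;  // depths[l] fits with every element in (l, r]
            ++l;
        } else {
            --r;  // depths[r] is too large even for the smallest remaining value
        }
    }
    return res;
}

// Hypothetical helper mirroring dfs(): subtreeDepths[c] lists the distances
// from the centroid of every vertex in child subtree c. Pairs inside one
// subtree are counted by the first call but do not pass through the
// centroid, so they are subtracted (like `ans -= calc(v, w[i])`).
long long pairsThroughCentroid(const std::vector<std::vector<int>>& subtreeDepths, int k) {
    std::vector<int> all = {0};  // the centroid itself, at distance 0
    long long inner = 0;
    for (const std::vector<int>& sub : subtreeDepths) {
        all.insert(all.end(), sub.begin(), sub.end());
        inner += countPairsWithin(sub, k);
    }
    long long total = countPairsWithin(all, k);
    assert(total >= inner);  // every subtracted pair was also counted in total
    return total - inner;
}
```

For the sample tree with vertex 1 as the centroid, the child subtrees have depths {3}, {1, 2} and {2}. The whole list {0, 1, 2, 2, 3} yields 8 pairs and the subtree {1, 2} (vertices 3 and 5) yields 1 spurious pair, so 7 valid pairs pass through vertex 1. The remaining valid pair (3, 5) is counted when the recursion reaches the component {3, 5}.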