POJ 1741 Tree

题意:给出一颗树,求出两个点的距离小于K的点对的个数.

sL : 树的点分治,代码中有解释。

  1 #include <cstdio>

 2 #include <cstring>
 3 #include <algorithm>
 4 #include <vector>
 5 #include <set>
 6 #include <iostream>
 7 using namespace std;
 8 typedef long long LL;
 9 const int MAX = 1e4+10;
10 struct node {
11     int v,L;
12     node () {};
13     node (int _v,int _L) : v(_v) , L(_L) {};
14 };
15 vector<node> G[MAX]; // 邻接边
16 vector<int> D; // 存储子树节点到根的深度
17 LL sz[MAX],f[MAX],root,res,cut[MAX];  
18 //孩子的个数 | 以当前的点分治孩子最多的一侧 | 根  | 结果 | 分治割点
19 int size,dep[MAX];  // 当前的根的子树大小 | 节点的深度
20 int n,k,a,b,c;  // 题目所需
21 void get_root(int u,int pre) {   // 以当前的根遍历求取 sz 和 f
22     sz[u]=1; f[u]=0;
23     for(int i=0;i<G[u].size();i++) {
24         node x=G[u][i];
25         if(x.v==pre||cut[x.v]) continue;
26         get_root(x.v,u);
27         sz[u]+=sz[x.v];
28         f[u]=max(f[u],sz[x.v]);
29     }
30     f[u]=max(f[u],size-f[u]);
31     if(f[root]>f[u])  root=u;
32 }
33 void get_dep(int u,int pre) {  //求每个节点到根的深度
34     D.push_back(dep[u]);
35     for(int i=0;i<G[u].size();i++) {
36         node x=G[u][i];
37         if(cut[x.v]||x.v==pre) continue;
38         dep[x.v]=dep[u]+x.L;
39         get_dep(x.v,u);
40 
41     }
42 }
43 LL solve(int u,int now) {  //统计 dep[i]+dep[j]=k 的对数
44     dep[u]=now; D.clear();
45     get_dep(u,-1);
46     LL ret=0;
47     sort(D.begin(),D.end());
48     for (int l=0,r=D.size()-1;l<r; )
49         if (D[l] + D[r] <= k) ret += r-l++;
50         else r--;
51     return ret;
52 }
53 void gao(int u) {
54     res+=solve(u,0);
55     cut[u]=true;
56     int v;
57     for(int i=0;i<G[u].size();i++) {
58         if (!cut[v = G[u][i].v]) {
59             res -= solve(v, G[u][i].L);  //减去在一颗子树内重复的情况
60             f[0] = size = sz[v];
61             get_root(v, root=0);
62             gao(root);
63         }
64     }
65 }
66 int main() {
67 
68     while(scanf("%d %d",&n,&k)==2&&(n||k)) {
69         memset(cut,0,sizeof(cut));
70         memset(sz,0,sizeof(sz));
71         for(int i=0;i<MAX;i++) G[i].clear();
72         for(int i=1;i<n;i++) {
73             scanf("%d %d %d",&a,&b,&c);
74             G[a].push_back(node(b,c));
75             G[b].push_back(node(a,c));
76         }
77         res=0; root=0;
78         f[0]=size=n;
79         get_root(1,0);
80         gao(root);
81         cout<<res<<endl;
82     }
83 }
posted @ 2014-09-25 20:55  acvc  阅读(154)  评论(0编辑  收藏  举报