树的分治学习

 

poj 1741 http://poj.org/problem?id=1741

题意:求树上距离小于k的点对的对数;

分析:每条路经要么过根,要么不过根,对于不过根的路径我们递归同理可以求出,

对于过根的路径,我们dfs一遍记录下所有其他节点到跟的距离,然后sort()一下可以o(n)求出

其小于k的pair,  但我们要减去其中来自同一子树的pair;

利用重心分治,最多logn次,所有时间o(n * logn * logn);

重心的定义:去掉该点后,其子树节点个数的最大值最小;

step1:对于当前树,找到其重心,统计pair;

step2:把重心当根,分别递归其子树;

findsz(),findC()为求重心;

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<algorithm>
 4 #include<cmath>
 5 #include<vector>
 6 #include<cstdlib>
 7 #define MP make_pair
 8 using namespace std;
 9 typedef pair<int,int> pii;
10 const int N = 10000+10;
11 vector<pii> g[N];
12 int n,KK;
13 
14 int cn[N];
15 int vis[N];
16 int findsz(int u,int fa) {
17     int ret = 1;
18     int sz = g[u].size();
19     for (int i = 0; i < sz; i++) {
20         int c = g[u][i].first;
21         if (c == fa || vis[c]) continue;
22         ret += findsz(c,u);
23     }
24     return ret;
25 }
26 void findC(int u,int fa,int &k,int &mark,int nn) {
27     int mx = 0;
28     int sz = g[u].size();
29     cn[u] = 1;
30     for (int i = 0; i < sz; i++) {
31         int c = g[u][i].first;
32         if (c == fa || vis[c]) continue;
33         findC(c,u,k,mark,nn);
34         cn[u] += cn[c];
35         if (cn[c] > mx) mx = cn[c];
36     }
37     if (nn - cn[u] > mx) mx = nn - cn[u];
38     if (mark == -1 || mx < mark) mark = mx,k = u;
39 }
40 int num[N];
41 int cnt;
42 void findnum(int u,int fa,int dep) {
43     num[cnt++] = dep;
44     int sz = g[u].size();
45     for (int i = 0; i < sz; i++) {
46         int c = g[u][i].first;
47         if (vis[c] || c == fa) continue;
48         findnum(c,u,dep+g[u][i].second);
49     }
50 }
51 int calc(int k,int w) {
52     cnt = 0;
53     findnum(k,0,w);
54     sort(num,num+cnt);
55     int r = cnt-1;
56     int ret = 0;
57     for (int i = 0; i < r; i++) {
58         while (num[i] + num[r] > KK && i < r) r--;
59         ret += r - i;
60     }
61     return ret;
62 }
63 int ans;
64 void dfs(int u,int w) {
65     int nn = findsz(u,0);
66     int k = 0, mark = -1;
67     findC(u,0,k,mark,nn);
68     int sz = g[k].size();
69     vis[k] = 1;
70     ans += calc(k,0);
71 
72     for (int i = 0; i < sz; i++) {
73         int c = g[k][i].first;
74         if (vis[c]) continue;
75         ans -= calc(c,g[k][i].second);
76         dfs(c,g[k][i].second);
77     }
78 }
79 void solve(){
80     memset(vis,0,sizeof(vis));
81     ans = 0;
82     dfs(1,0);
83     printf("%d\n",ans);
84 }
85 int main(){
86     while (~scanf("%d%d",&n,&KK),n+KK) {
87         for (int i = 0; i <= n; i++) g[i].clear();
88         for (int i = 0; i < n-1; i++) {
89             int u,v,w; scanf("%d%d%d",&u,&v,&w);
90             g[u].push_back(MP(v,w));
91             g[v].push_back(MP(u,w));
92         }
93         solve();
94     }
95 
96     return 0;
97 }
View Code

 

 

posted @ 2013-11-19 21:36  Rabbit_hair  阅读(409)  评论(0编辑  收藏  举报