树的分治学习
poj 1741 http://poj.org/problem?id=1741
题意:求树上距离小于k的点对的对数;
分析:每条路经要么过根,要么不过根,对于不过根的路径我们递归同理可以求出,
对于过根的路径,我们dfs一遍记录下所有其他节点到跟的距离,然后sort()一下可以o(n)求出
其小于k的pair, 但我们要减去其中来自同一子树的pair;
利用重心分治,最多logn次,所有时间o(n * logn * logn);
重心的定义:去掉该点后,其子树节点个数的最大值最小;
step1:对于当前树,找到其重心,统计pair;
step2:把重心当根,分别递归其子树;
findsz(),findC()为求重心;
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #include<cmath> 5 #include<vector> 6 #include<cstdlib> 7 #define MP make_pair 8 using namespace std; 9 typedef pair<int,int> pii; 10 const int N = 10000+10; 11 vector<pii> g[N]; 12 int n,KK; 13 14 int cn[N]; 15 int vis[N]; 16 int findsz(int u,int fa) { 17 int ret = 1; 18 int sz = g[u].size(); 19 for (int i = 0; i < sz; i++) { 20 int c = g[u][i].first; 21 if (c == fa || vis[c]) continue; 22 ret += findsz(c,u); 23 } 24 return ret; 25 } 26 void findC(int u,int fa,int &k,int &mark,int nn) { 27 int mx = 0; 28 int sz = g[u].size(); 29 cn[u] = 1; 30 for (int i = 0; i < sz; i++) { 31 int c = g[u][i].first; 32 if (c == fa || vis[c]) continue; 33 findC(c,u,k,mark,nn); 34 cn[u] += cn[c]; 35 if (cn[c] > mx) mx = cn[c]; 36 } 37 if (nn - cn[u] > mx) mx = nn - cn[u]; 38 if (mark == -1 || mx < mark) mark = mx,k = u; 39 } 40 int num[N]; 41 int cnt; 42 void findnum(int u,int fa,int dep) { 43 num[cnt++] = dep; 44 int sz = g[u].size(); 45 for (int i = 0; i < sz; i++) { 46 int c = g[u][i].first; 47 if (vis[c] || c == fa) continue; 48 findnum(c,u,dep+g[u][i].second); 49 } 50 } 51 int calc(int k,int w) { 52 cnt = 0; 53 findnum(k,0,w); 54 sort(num,num+cnt); 55 int r = cnt-1; 56 int ret = 0; 57 for (int i = 0; i < r; i++) { 58 while (num[i] + num[r] > KK && i < r) r--; 59 ret += r - i; 60 } 61 return ret; 62 } 63 int ans; 64 void dfs(int u,int w) { 65 int nn = findsz(u,0); 66 int k = 0, mark = -1; 67 findC(u,0,k,mark,nn); 68 int sz = g[k].size(); 69 vis[k] = 1; 70 ans += calc(k,0); 71 72 for (int i = 0; i < sz; i++) { 73 int c = g[k][i].first; 74 if (vis[c]) continue; 75 ans -= calc(c,g[k][i].second); 76 dfs(c,g[k][i].second); 77 } 78 } 79 void solve(){ 80 memset(vis,0,sizeof(vis)); 81 ans = 0; 82 dfs(1,0); 83 printf("%d\n",ans); 84 } 85 int main(){ 86 while (~scanf("%d%d",&n,&KK),n+KK) { 87 for (int i = 0; i <= n; i++) g[i].clear(); 88 for (int i = 0; i < n-1; i++) { 89 int u,v,w; scanf("%d%d%d",&u,&v,&w); 90 g[u].push_back(MP(v,w)); 91 g[v].push_back(MP(u,w)); 92 } 93 solve(); 94 } 95 96 return 0; 97 }