poj 1741 Tree 分治在树上应用
关于分治算法在树上的应用详情请查看09年QZC国家集训队论文。
题目大意: 树含N个点,点之间有权值,求两点间权值和小于等于K的点对数量( N <= 10000 )
解题思路:对于以rt为根节点的树,其树上两点间一条路径只有两种情况,分别为过根节点,不过根节点。
这样,启发了我们使用分治的思想来解决此题。
若不过根节点,则通过递归处理,其实也可理解为过根节点,但过了根的那部分为0.可简化代码
若过根节点, 则 dist(i)+dist(j) <= K 且 father(i) != father(j) 其中哦功能 dist(i) 为 子树上节点到根节点rt的距离, father(i)为子树上节点i 是属于 rt的哪一个子节点
转换下,我们需要的结果就是 dist(i) + dist)j) <= K -- dist(i) + dist(j) <= K 且 father(i) == father(j)
两部分都是要求 dist(i) + dist(j) <= K , 我们可以对根节点 rt 求一次 满足此条件的所有点对数量, 然后对根节点 rt的 所有字节点求一次( 直接子节点 ) 然后相减就是我们所求过根节点的点对数量
另, 对于当前子树 满足 dist(i) + dist(j) <= K 的点对数量, 我们可以将所有距离排序后, 通过两个下标不回溯的思想,达到O(n)的时间复杂度来统计,这样只需要Nlog(N)的时间复杂度。
另,计算完成当前根节点后,再递归到子节点继续求取结果, 总的时间复杂度为 N*logN*logN ,
这里要注意,每次递归到子树后,选出树的重心作为根节点可大幅度降低时间,这样最坏情况是N/2层,当树为链形达到极端情况。关于求得树的重心可以通过O(n)的算法。
解题代码:
#include<stdio.h> #include<string.h> #include<stdlib.h> #include<iostream> #include<algorithm> #include<vector> using namespace std; #define MIN(a,b) (a)<(b)?(a):(b) #define MAX(a,b) (a)>(b)?(a):(b) const int inf = 0x3fffffff; const int maxn = 10010; int N, K, ans; int D[maxn], M[maxn]; int head[maxn], idx; bool vis[maxn]; int key[maxn], cnt; struct Edge{ int v, c, next; }edge[maxn<<2]; vector<int> S[maxn], Q, P; void AddEdge( int u, int v, int c) { edge[idx].v = v; edge[idx].c = c; edge[idx].next = head[u]; head[u] = idx++; edge[idx].v = u; edge[idx].c = c; edge[idx].next = head[v]; head[v] = idx++; } void Input() { memset( head, 0xff, sizeof(head) ); idx = 0; int a, b, c; for(int i = 0; i < N-1; i++) { scanf("%d%d%d", &a,&b,&c); AddEdge( a, b, c ); } } // 取树的中心rt int getsum( int u,int pre ) { int tot = 1; for(int i = head[u]; ~i; i = edge[i].next ) if( !vis[edge[i].v] && edge[i].v != pre ) { int t = getsum( edge[i].v,u ); M[u] = MAX( M[u], t ); tot += t; } return tot; } int getrt( int u, int pre, int n ) { int key = u; M[u] = MAX( M[u], n-1-M[u] ); for(int i = head[u]; ~i; i = edge[i].next ) if( !vis[edge[i].v] && edge[i].v != pre ) { int t = getrt( edge[i].v, u, n ); if( M[t] < M[key] ) key = t; } return key; } int GetRt(int x) { memset( M, 0, sizeof(M) ); int n = getsum( x, 0 ); int rt = getrt( x, 0, n ); return rt; } //**************************** void GetDist( int u, int pre, int c ) { key[cnt++] = c; for(int i = head[u]; ~i; i = edge[i].next ) if( !vis[edge[i].v] && edge[i].v != pre && c+edge[i].c <= K ) GetDist(edge[i].v, u, c+edge[i].c ); } int Count( ) { sort( key, key+cnt ); int s = 0, l = 0, r = cnt-1; while( l < r ) { if( key[l]+key[r] <= K ) s += r-l, l++; else r--; } return s; } int GetS2( int u, int pre ) { //属于一颗子数上点 int tot = 0; for(int i = head[u]; ~i; i = edge[i].next ) { if( !vis[ edge[i].v ] && edge[i].v != pre ) { cnt = 0; GetDist( edge[i].v, u, edge[i].c ); tot += Count(); } } return tot; } void solve( int x, int pre ){ int rt = GetRt(x), s1, s2; cnt = 0; GetDist( rt, 0, 0 ); s1 = Count(), s2 = GetS2(rt, 0); ans += (s1-s2); // printf("ans = %d, rt = %d\n", ans, rt ); vis[rt] = true;//从图中删除此点 for(int i = head[rt]; ~i; i = edge[i].next ) ////递归到问题 if( !vis[ edge[i].v ] && edge[i].v != pre ) solve( edge[i].v, rt ); } int main() { while( scanf("%d%d", &N,&K) , N+K ) { Input(); memset( vis, 0, sizeof(vis) ); ans = 0; solve( 1, 0 ); printf("%d\n", ans ); } }