POJ_1741
具体的思路可以参考漆子超的《分治算法在树的路径问题中的应用》这篇论文。
对于树的分治的题目,一开始理解起来有点头疼,主要是一开始没弄明白为什么分治以后可以改善复杂度,其实关键的操作就在于对于当前这棵树,重新找到一个“合适”(何为“合适”,详见论文)的点作为根节点,并将这个点删除,再依次递归这个点的各个子树。
如果我们将这些点按删除的顺序重新构造一棵树的话,论文上有证明这棵树的树高是O(logN)的,对于每一层的节点而言,遍历一遍这层节点所有的子树的总复杂度为O(N)。这样,对于这棵树每个节点都执行一遍遍历所有子树的操作的总的复杂度就是O(NlogN)了,从而达到了降低复杂度的目的。如果按原树的结构每个节点都执行一遍遍历所有子树的操作,最坏的复杂度是O(N^2),比如这棵树是一条链。
由于我们降低了遍历每个节点子树的复杂度,于是就在子树遍历的基础上得到了树的分治的算法。由于需要用到快排,对于每层节点而言,将所有子节点进行快排的总的复杂度是O(NlogN),相比遍历每层节点的子树而言多了就多了一个logN,因此总的复杂度是O(N*logN*logN)
需要注意的是,由于每次都对子树进行重新选根的操作,这样可以得到以newroot为根的各个depth值,在结束对这棵子树的递归时,需要先还原这棵子树原来形态下的depth值,后续的一些操作才能顺利进行,同样遍历一遍子树就可以得到新的depth值。
#include<stdio.h> #include<string.h> #include<algorithm> #define MAXD 10010 #define MAXM 20010 #define INF 0x3f3f3f3f using namespace std; int N, K, e, first[MAXD], next[MAXM], v[MAXM], w[MAXM]; int q[MAXD], fa[MAXD], p[MAXD], del[MAXD], size[MAXD], res; void add(int x, int y, int z) { v[e] = y, w[e] = z; next[e] = first[x], first[x] = e ++; } int Max(int x, int y) { return x > y ? x : y; } void init() { int i, x, y, z; e = 0; memset(first, -1, sizeof(first)); for(i = 1; i < N; i ++) { scanf("%d%d%d", &x, &y, &z); add(x, y, z), add(y, x, z); } } int findroot(int cur) { int i, j, x, rear = 0, max, min = INF, root; q[rear ++] = cur, fa[cur] = 0; for(i = 0; i < rear; i ++) { x = q[i]; for(j = first[x]; j != -1; j = next[j]) if(!del[v[j]] && v[j] != fa[x]) q[rear ++] = v[j], fa[v[j]] = x; } for(i = rear - 1; i >= 0; i --) { x = q[i]; size[x] = 1, max = 0; for(j = first[x]; j != -1; j = next[j]) if(!del[v[j]] && v[j] != fa[x]) size[x] += size[v[j]], max = Max(max, size[v[j]]); max = Max(max, rear - size[x]); if(max < min) min = max, root = x; } return root; } int deal(int s, int t) { int i, rear = t, ans = 0; for(i = s; i <= t; i ++) { while(rear >= s && p[rear] + p[i] > K) -- rear; ans += rear - s + 1; } return ans; } void renew(int s, int cur, int d) { int i, j, x, rear = 0; q[rear ++] = cur, fa[cur] = 0, p[s] = d; for(i = 0; i < rear; i ++) { x = q[i]; for(j = first[x]; j != -1; j = next[j]) if(!del[v[j]] && v[j] != fa[x]) p[s + rear] = p[s + i] + w[j], q[rear ++] = v[j], fa[v[j]] = x; } sort(p + s, p + s + rear); } int dfs(int s, int cur, int d) { int i, root, n, tot = 1; root = findroot(cur); del[root] = 1; for(i = first[root]; i != -1; i = next[i]) if(!del[v[i]]) { n = dfs(s + tot, v[i], w[i]); res -= deal(s + tot, s + tot + n - 1); tot += n; } p[s] = 0; sort(p + s, p + s + tot); res += deal(s, s + tot - 1) - 1; del[root] = 0; renew(s, cur, d); return tot; } void solve() { memset(del, 0, sizeof(del)); res = 0; dfs(0, 1, 0); printf("%d\n", res >> 1); } int main() { for(;;) { scanf("%d%d", &N, &K); if(!N && !K) break; init(); solve(); } return 0; }