【树形+点分治】POJ 1741 Tree
通道:http://poj.org/problem?id=1741
题意:有多少对[u,v]的距离小于K
思路:
将无根树转化成有根树进行观察。满足条件的点对有两种情况:两个点的路径横跨树根,两个点位于同一颗子树中。
如果我们已经知道了此时所有点到根的距离a[i],a[x] + a[y] <= k的(x, y)对数就是结果,这个可以通过排序之后O(n)的复杂度求出。然后根据分治的思想,分别对所有的儿子求一遍即可,但是这会出现重复的——当前情况下两个点位于一颗子树中,那么应该将其减掉(显然这两个点是满足题意的,为什么减掉呢?因为在对子树进行求解的时候,会重新计算)。在进行分治时,为了避免树退化成一条链而导致时间复杂度变为O(N^2),每次都找树的重心,这样,所有的子树规模就会变的很小了。时间复杂度O(Nlog^2N)。
代码:
#include <cstdio> #include <cmath> #include <cstring> #include <vector> #include <algorithm> using namespace std; typedef long long ll; const int MAX_N = 100007; struct Node { int v, w, nxt; Node () { } Node (int _v, int _w, int _n) { v = _v; w = _w; nxt = _n; } }; int n, k; int head[MAX_N], edgecnt; Node G[MAX_N << 1]; bool del[MAX_N]; void Clear() { edgecnt = 0; memset(head, -1, sizeof head); memset(del, 0, sizeof del); } void add(int u, int v, int w) { G[edgecnt] = Node(v, w, head[u]); head[u] = edgecnt++; } int son[MAX_N], opt[MAX_N]; vector<int> alln; void dfs(int u,int fa) { alln.push_back(u); son[u] = 1; opt[u] = 0; for(int i = head[u]; ~i; i = G[i].nxt) { int v = G[i].v; if(del[v] || v == fa) continue; dfs(v, u); son[u] += son[v]; opt[u] = max(opt[u], son[v]); } } int getCenter(int u) { alln.clear(); dfs(u, -1); int mx = 0, ans = -1; int sz = alln.size(); for(int i = 0; i < sz; ++i) { int v = alln[i]; if(ans == -1) ans = v, mx = max(opt[v], sz - son[v]); else { if(max(opt[v], sz - son[v]) < mx) { mx = max(opt[v], sz - son[v]); ans = v; } } } return ans; } int tot, D[MAX_N]; void getDist(int u, int fa, int w) { D[tot++] = w; for(int i = head[u]; ~i; i = G[i].nxt) { int v = G[i].v; if(del[v] || v == fa) continue; getDist(v, u, w + G[i].w); } } ll calc() { sort(D, D + tot); ll ans = 0; int l = 0, r = tot - 1; while (l < r) { if (D[l] + D[r] <= k) ans += r - l++; else --r; } return ans; } ll ans; void work(int u) { u = getCenter(u); tot = 0; getDist(u, -1, 0); ans += calc(); for(int i = head[u]; ~i; i = G[i].nxt) { int v = G[i].v; if(del[v]) continue; tot = 0; getDist(v, u, G[i].w); ans -= calc(); } del[u] = true; for(int i = head[u]; ~i; i =G[i].nxt) { int v = G[i].v; if(del[v]) continue; work(v); } } int main() { while (2 == scanf("%d%d", &n, &k)) { if (0 == n && 0 == k) break; Clear(); for (int i = 1; i < n; ++i) { int u, v, w; scanf("%d%d%d", &u, &v, &w); add(u, v, w), add(v, u, w); } ans = 0; work(1); printf("%lld\n", ans); } return 0; }