[POJ1741] Tree【树分治 点分治】
传送门:http://poj.org/problem?id=1741
写的第一道树分治题,撒花纪念~
对于每一对点对(i, j),它有三种情况:
① 其中一个是根节点。这种情况比较简单,直接加上就好了。
② 横跨根节点。这种情况是重点。
③ 不是以上两种情况。这时递归下去求解就好了。
那么对于第二种情况该怎么破呢?设根节点为root,那么dist(i, root) + dist(j, root) <= k,且需要i与j在不同的子树里。直接算不同子树的点对(i, j)的个数会麻烦,所以需要一点技巧:符合条件且在不同子树的(i, j)的对数 = 符合条件的对数 - 符合条件且在相同子树的(i, j)的对数,这样就搞定啦!
#include <cstdio> #include <cstring> #include <algorithm> const int maxn = 10005; int n, k, t1, t2, t3, ans; int head[maxn], to[maxn << 1], next[maxn << 1], w[maxn << 1], lb; int siz[maxn], a[maxn], left, right; bool book[maxn]; inline void ist(int aa, int ss, int ww) { to[lb] = ss; next[lb] = head[aa]; head[aa] = lb; w[lb] = ww; ++lb; } int fnd_zx(int fr, int tot_node, int p, int & rt, int & mn) { int mx = 0; for (int j = head[fr]; j != -1; j = next[j]) { if (!book[to[j]] && to[j] != p) { fnd_zx(to[j], tot_node, fr, rt, mn); mx = std::max(mx, siz[to[j]]); } } mx = std::max(mx, tot_node - siz[fr]); if (mn > mx) { mn = mx; rt = fr; } } void get_siz(int fr, int p) { siz[fr] = 1; for (int j = head[fr]; j != -1; j = next[j]) { if (!book[to[j]] && to[j] != p) { get_siz(to[j], fr); siz[fr] += siz[to[j]]; } } } void get_data(int r, int p, int ww) { if (ww > k) { return; } a[right++] = ww; for (int j = head[r]; j != -1; j = next[j]) { if (!book[to[j]] && to[j] != p) { get_data(to[j], r, ww + w[j]); } } } int get_ans(int l, int r) { std::sort(a + l, a + r); int rt = 0; --r; while (r > l) { while (r > l && a[l] + a[r] > k) { --r; } rt += r - l; ++l; } return rt; } void slove(int fr) { int root = -666, mn = 2147483647; get_siz(fr, 0); fnd_zx(fr, siz[fr], 0, root, mn); book[root] = true; for (int j = head[root]; j != -1; j = next[j]) { if (!book[to[j]]) { slove(to[j]); } } left = right = 0; for (int j = head[root]; j != -1; j = next[j]) { if (!book[to[j]]) { get_data(to[j], root, w[j]); ans -= get_ans(left, right); left = right; } } ans += get_ans(0, right) + right; book[root] = false; } int main(void) { //freopen("in.txt", "r", stdin); while (scanf("%d%d", &n, &k) && n && k) { lb = 0; memset(head, -1, sizeof head); memset(next, -1, sizeof next); ans = 0; for (int i = 1; i < n; ++i) { scanf("%d%d%d", &t1, &t2, &t3); ist(t1, t2, t3); ist(t2, t1, t3); } slove(1); printf("%d\n", ans); } return 0; }