[POJ1741] Tree【树分治 点分治】

传送门:http://poj.org/problem?id=1741

写的第一道树分治题,撒花纪念~

对于每一对点对(i, j),它有三种情况:

① 其中一个是根节点。这种情况比较简单,直接加上就好了。

② 横跨根节点。这种情况是重点。

③ 不是以上两种情况。这时递归下去求解就好了。

那么对于第二种情况该怎么破呢?设根节点为root,那么dist(i, root) + dist(j, root) <= k,且需要i与j在不同的子树里。直接算不同子树的点对(i, j)的个数会麻烦,所以需要一点技巧:符合条件且在不同子树的(i, j)的对数 = 符合条件的对数 - 符合条件且在相同子树的(i, j)的对数,这样就搞定啦!

 

#include <cstdio>
#include <cstring>
#include <algorithm>

const int maxn = 10005;

int n, k, t1, t2, t3, ans;
int head[maxn], to[maxn << 1], next[maxn << 1], w[maxn << 1], lb;
int siz[maxn], a[maxn], left, right;
bool book[maxn];

inline void ist(int aa, int ss, int ww) {
	to[lb] = ss;
	next[lb] = head[aa];
	head[aa] = lb;
	w[lb] = ww;
	++lb;
}
int fnd_zx(int fr, int tot_node, int p, int & rt, int & mn) {
	int mx = 0;
	for (int j = head[fr]; j != -1; j = next[j]) {
		if (!book[to[j]] && to[j] != p) {
			fnd_zx(to[j], tot_node, fr, rt, mn);
			mx = std::max(mx, siz[to[j]]);
		}
	}
	mx = std::max(mx, tot_node - siz[fr]);
	if (mn > mx) {
		mn = mx;
		rt = fr;
	}
}
void get_siz(int fr, int p) {
	siz[fr] = 1;
	for (int j = head[fr]; j != -1; j = next[j]) {
		if (!book[to[j]] && to[j] != p) {
			get_siz(to[j], fr);
			siz[fr] += siz[to[j]];
		}
	}
}
void get_data(int r, int p, int ww) {
	if (ww > k) {
		return;
	}
	a[right++] = ww;
	for (int j = head[r]; j != -1; j = next[j]) {
		if (!book[to[j]] && to[j] != p) {
			get_data(to[j], r, ww + w[j]);
		}
	}
}
int get_ans(int l, int r) {
	std::sort(a + l, a + r);
	int rt = 0;
	--r;
	while (r > l) {
		while (r > l && a[l] + a[r] > k) {
			--r;
		}
		rt += r - l;
		++l;
	}
	return rt;
}
void slove(int fr) {
	int root = -666, mn = 2147483647;
	get_siz(fr, 0);
	fnd_zx(fr, siz[fr], 0, root, mn);
	book[root] = true;
	for (int j = head[root]; j != -1; j = next[j]) {
		if (!book[to[j]]) {
			slove(to[j]);
		}
	}
	left = right = 0;
	for (int j = head[root]; j != -1; j = next[j]) {
		if (!book[to[j]]) {
			get_data(to[j], root, w[j]);
			ans -= get_ans(left, right);
			left = right;
		}
	}
	ans += get_ans(0, right) + right;
	book[root] = false;
}

int main(void) {
	//freopen("in.txt", "r", stdin);
	while (scanf("%d%d", &n, &k) && n && k) {
		lb = 0;
		memset(head, -1, sizeof head);
		memset(next, -1, sizeof next);
		ans = 0;
		for (int i = 1; i < n; ++i) {
			scanf("%d%d%d", &t1, &t2, &t3);
			ist(t1, t2, t3);
			ist(t2, t1, t3);
		}
		slove(1);
		printf("%d\n", ans);
	}
	return 0;
}

 

  

 

posted @ 2016-11-21 20:09  ciao_sora  阅读(113)  评论(0编辑  收藏  举报