POJ 1741:Tree(树上点分治)
题意
给一棵边带权树,问两点之间的距离小于等于K的点对有多少个。
思路
图片转载于http://www.cnblogs.com/Paul-Guderian/p/6782671.html
我对于点分治的理解:对于树上的一些问题,可以转化为答案只与当前根有关的问题,然后分治递归求解每一棵子树,统计答案。找的根应当是当前子树的重心,具体证明可以看上面的论文。
对于当前正在处理的树,这棵树的路径有两种情况:
-
经过根结点。
-
不经过根节点(在子树内)。
对于第二种情况, 我们可以递归求解转化为第一种情况来处理。于是问题变成求解第一种情况了。
这道题在cal统计答案的时候,因为我们在处理以 root
为根节点的子树的答案贡献的时候,求的是在不同子树中的距离小于等于k的点对(第一种情况),但是我们cal出来的是两种情况都包括的,因此需要减去第二种情况,即再cal一遍处理子树。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int INF = 0x3f3f3f3f;
const int N = 1e5 + 10;
typedef long long LL;
struct Edge {
int v, nxt, w;
} edge[N*2];
int n, k, head[N], tot, dep[N], son[N], dis[N], f[N], vis[N], sum, root, ans;
void Add(int u, int v, int w) {
edge[tot] = (Edge) { v, head[u], w }; head[u] = tot++;
edge[tot] = (Edge) { u, head[v], w }; head[v] = tot++;
}
void getroot(int u, int fa) { // 找重心
son[u] = 1; f[u] = 0;
for(int i = head[u]; ~i; i = edge[i].nxt) {
int v = edge[i].v;
if(v == fa || vis[v]) continue;
getroot(v, u);
son[u] += son[v];
f[u] = max(f[u], son[v]); // 最大的子树
}
// 当前的树中除了以u为根的树以外的结点数
// 因为当以u为根的话,除了u为根的树的结点之外的所有结点在一个子树里面
f[u] = max(f[u], sum - son[u]);
// 找一个根节点使得最大的子树最小
if(f[u] < f[root]) root = u;
}
void getdeep(int u, int fa) {
// 处理出dep数组,也是当前点到根节点的距离的数组,dep[0]表示数量
dep[++dep[0]] = dis[u];
for(int i = head[u]; ~i; i = edge[i].nxt) {
int v = edge[i].v, w = edge[i].w;
if(vis[v] || v == fa) continue;
dis[v] = dis[u] + w;
getdeep(v, u);
}
}
int cal(int u, int now) {
dep[0] = 0, dis[u] = now;
getdeep(u, 0);
sort(dep + 1, dep + 1 + dep[0]);
int res = 0, l = 1, r = dep[0];
while(l < r) {
// 对于连着l和r的两个端点,之间的所有点都可以使得距离小于等于k
if(dep[l] + dep[r] <= k) res += r - l, l++;
else r--;
} return res;
}
void work(int u) {
// 计算满足dep(i)+dep(j)<=k的数目
ans += cal(u, 0);
vis[u] = 1;
for(int i = head[u]; ~i; i = edge[i].nxt) {
int v = edge[i].v, w = edge[i].w;
if(vis[v]) continue;
// 减去满足dep(i)+dep(j)<=k并且i和j在同一个子树的数目(第二种情况)
ans -= cal(v, w);
sum = son[v];
getroot(v, root = 0); // 递归处理子树
// printf("root : %d\n", root);
work(root);
}
}
int main() {
while(~scanf("%d%d", &n, &k), n + k) {
memset(head, -1, sizeof(head)); tot = 0;
memset(vis, 0, sizeof(vis));
for(int i = 1; i < n; i++) {
int u, v, w; scanf("%d%d%d", &u, &v, &w);
Add(u, v, w);
}
sum = n, f[0] = INF, ans = 0, root = 0;
getroot(1, 0);
// printf("root : %d\n", root);
work(root);
printf("%d\n", ans);
}
return 0;
}
/*
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
*/