P4178 Tree
P4178 Tree
题目描述
给定一棵 n 个节点的树,每条边有边权,求出树上两点距离小于等于 k 的点对数量。
输入格式
第一行输入一个整数 n,表示节点个数。
第二行到第 n 行每行输入三个整数 u,v,w ,表示 u 与 v 有一条边,边权是 w。
第 n+1 行一个整数 k 。
输出格式
一行一个整数,表示答案。
输入输出样例
输入 #1
7
1 6 13
6 3 9
3 5 7
4 1 3
2 4 20
4 7 2
10
输出 #1
5
说明/提示
数据规模与约定
对于全部的测试点,保证:
-
1≤n≤\(4×10^4\)
-
1≤u,v≤n
-
0≤w≤\(10^3\)
-
0≤k≤\(2×10^4\)
经典的点分治题目。
我们先处理出这个联通块内所有点到重心的距离,将所有距离排个序,用双指针,一个从小往大找,另一个从大往下找,这样可以快速的统计出小于等于k的路径,注意要减去同一颗子树内的答案。
#include <iostream> #include <cstdio> #include <cctype> #include <algorithm> using namespace std; inline long long read() { long long s = 0, f = 1; char ch; while(!isdigit(ch = getchar())) (ch == '-') && (f = -f); for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48)); return s * f; } const int N = 4e4 + 5, inf = 1e9; int n, x, y, z, k, cnt, num, ans, root, totsize; int a[N], siz[N], dis[N], vis[N], max_siz[N], head[N]; struct edge { int to, nxt, val; } e[N << 1]; void add(int x, int y, int z) { e[++cnt].nxt = head[x]; head[x] = cnt; e[cnt].to = y; e[cnt].val = z; } void init() { n = read(); for(int i = 1;i <= n - 1; i++) { x = read(); y = read(); z = read(); add(x, y, z); add(y, x, z); } k = read(); } void get_root(int x, int fa) { siz[x] = 1; for(int i = head[x]; i; i = e[i].nxt) { int y = e[i].to; if(y == fa || vis[y]) continue; get_root(y, x); siz[x] += siz[y]; max_siz[x] = max(max_siz[x], siz[y]); } max_siz[x] = max(max_siz[x], totsize - siz[x]); if(max_siz[x] < max_siz[root]) root = x; } void calc_dis(int x, int fa) { a[++num] = dis[x]; for(int i = head[x]; i; i = e[i].nxt) { int y = e[i].to; if(y == fa || vis[y]) continue; dis[y] = dis[x] + e[i].val; calc_dis(y, x); } } void solve(int x) { vis[x] = 1; dis[x] = num = 0; calc_dis(x, 0); sort(a + 1, a + num + 1); int r = num; for(int i = 1;i <= num; i++) { while(a[i] + a[r] > k && r >= i) r--; if(r < i) break; ans += (r - i); } for(int i = head[x]; i; i = e[i].nxt) { int y = e[i].to; if(vis[y]) continue; dis[y] = e[i].val; num = 0; calc_dis(y, x); sort(a + 1, a + num + 1); r = num; for(int j = 1;j <= num; j++) { while(a[j] + a[r] > k && r >= j) r--; if(r < j) break; ans -= (r - j); } max_siz[root = 0] = inf; totsize = siz[y]; get_root(y, 0); solve(root); } } void work() { max_siz[root = 0] = inf; totsize = n; get_root(1, 0); solve(root); printf("%d", ans); } int main() { init(); work(); return 0; }