jzoj 1166. 树中点对距离
Description
给出一棵带边权的树,问有多少对点的距离<=\(len\)
100% \(2<=n<=10000,len<=maxlongint\)
Solution
这一题可以说是点分治的模板题了。
我们按照套路,先求重心,在计算答案。
如何计算答案?
设当前点为\(x\)
我们先\(O(n)\)搜一遍求出当前树的每个点的深度。
分类讨论:
经过\(x\),那么只需满足\(dep[a] + dep[b] <= len\)即可。
不经过\(x\),那么这个可以在那一颗子树中计算到。
所以,我们需要容斥,将不经过\(x\)的减掉。
由于他们到\(x\)还有相同的一段路径(设长\(length\)),所以要将他们的\(dep\)之和加上\(2*length\)。
Code
#include <cstdio>
#include <algorithm>
#define N 10010
using namespace std;
struct node{int v, fr, l;}e[N << 1];
int n, len, rt, cnt = 0, ans = 0, tot;
int tail[N], siz[N], son[N], dep[N], size = 0;
bool bz[N];
inline int read()
{
int x = 0; char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
return x;
}
void add(int u, int v, int l) {e[++cnt] = (node){v, tail[u], l}; tail[u] = cnt;}
void getrt(int x, int fa)
{
siz[x] = 1; son[x] = 0;
for (int p = tail[x], v; p; p = e[p].fr)
{
v = e[p].v;
if (v == fa || bz[v]) continue;
getrt(v, x);
if (son[x] < siz[v]) son[x] = siz[v];
siz[x] += siz[v];
}
if (son[x] < tot - siz[x]) son[x] = tot - siz[x];
if (son[x] < son[rt]) rt = x;
}
void getdeep(int x, int fa, int deep)
{
dep[++size] = deep;
for (int p = tail[x], v; p; p = e[p].fr)
{
v = e[p].v;
if (v == fa || bz[v]) continue;
getdeep(v, x, deep + e[p].l);
}
}
int cal(int x, int num)//calculation
{
size = 0, getdeep(x, 0, num);
sort(dep + 1, dep + size + 1);
int l = 1, r = size, s = 0;
while (l < r)
{
while (dep[l] + dep[r] <= len && l < r) s += r - l, l++;
while (dep[l] + dep[r] > len && l < r) r--;
}
return s;
}
void solve(int x)
{
// printf("%d\n", x);
bz[x] = 1;
ans += cal(x, 0);
for (int p = tail[x], v; p; p = e[p].fr)
{
v = e[p].v;
if (bz[v]) continue;
ans -= cal(v, e[p].l);
tot = siz[v], rt = 0, getrt(v, 0);
solve(rt);
}
}
int main()
{
freopen("distance.in", "r", stdin);
freopen("distance.out", "w", stdout);
n = read(), len = read(); son[0] = 1e9;
for (int i = 1, u, v, l; i < n; i++)
u = read(), v = read(), l = read(), add(u, v, l), add(v, u, l);
tot = n, rt = 0, getrt(1, 0);
solve(rt);
printf("%d\n", ans);
return 0;
}
转载需注明出处。