[POJ1741]Tree
我们知道,树上两个点的LCA要么是当前根节点,要么不是。。所以两个点间的最短路径要么经过当前根节点,要么在一棵当前根节点的子树中。。
考虑点分治,于是在原来同一子树中的两个点必然在一次分治中变为路径经过当前根节点的两个点。
处理路径经过当前根节点的两个点的情况。对于当前树,每个节点(根节点除外)记录深度\(dep_i\)(根节点深度为\(0\))和除当前根节点外的最远祖先\(fa_i\)。。
于是有:
\[\sum [fa_i\ne fa_j \land dep_i+dep_j \le K]
\]
显然,式子等于:
\[\sum [dep_i+dep_j\le K]-\sum[fa_i=fa_j\land dep_i+dep_j\le K]
\]
于是可以这样解决:
在当前树中,将\(dep\)排序,用\(l\)表示左指针,\(r\)表示右指针,\(l\)从左向右遍历。如果\(dep_l+dep_r\le k\),则点对\((l,t)(i<t\le r)\)都符合题意,于是将\(r-l\)加入答案中,并且\(l\)++;否则\(r\)--。
需要注意的是链的情况。。时间复杂度会退化成\(O(N^2)\)。。我们可以将重心作为根,以保证复杂度为\(O(Nlog^2N)\)
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 10005;
struct edge {
int v, l;
edge(int v_, int l_) :v(v_), l(l_) {};
};
vector<edge> g[MAXN];
vector<int> dep;
int n, k, dist[MAXN], vis[MAXN], f[MAXN], root, ans, s[MAXN], tot;
inline void getroot(int now, int fa) {
int u;
s[now] = 1, f[now] = 0;
for (int i = 0; i < g[now].size(); i++) {
u = g[now][i].v;
if (u != fa && !vis[u])
getroot(u, now),
s[now] += s[u],
f[now] = max(f[now], s[u]);
}
f[now] = max(f[now], tot - s[now]);
if (f[now] < f[root]) root = now;
}
inline void getdep(int now, int fa) {
int u;
dep.push_back(dist[now]),
s[now] = 1;
for (int i = 0; i < g[now].size(); i++) {
u = g[now][i].v;
if (u != fa && !vis[u])
dist[u] = dist[now] + g[now][i].l,
getdep(u, now),
s[now] += s[u];
}
}
inline int calc(int now, int len) {
dep.clear(),dist[now] = len;
getdep(now, 0),
sort(dep.begin(), dep.end());
int cnt = 0, l = 0, r = dep.size() - 1;
while (l < r)
if (dep[l] + dep[r] <= k) cnt += r - l, l++;
else r--;
return cnt;
}
inline void work(int now) {
int u;
ans += calc(now, 0),
vis[now] = true;
for (int i = 0; i < g[now].size(); i++) {
u = g[now][i].v;
if (!vis[u])
ans -= calc(u, g[now][i].l),
f[0] = tot = s[u],
root = 0,
getroot(u, 0),
work(root);
}
}
int main() {
while (~scanf("%d%d",&n,&k)) {
if (!n && !k) break;
for (int i = 0; i <= n; i++) g[i].clear();
memset(vis, 0, sizeof(int)*(n+1));
int u, v, l;
for (int i = 1; i < n; i++)
scanf("%d%d%d",&u,&v,&l),
g[u].push_back(edge(v, l)),g[v].push_back(edge(u, l));
f[0] = n,root = 0,tot = n;
getroot(1, 0),
ans = 0,
work(root),
printf("%d\n",ans);
}
}