K shortest path on tree 题解
一、题目:
二、思路:
个人觉得这是一道经典的套路题,需要牢记这种处理方式。
其实这道题一个非常重要的性质就是:树的形态随机。在这种情况下,有一个很重要的结论:树高期望为 \(\log n\)。或者我们可以理解成,这棵树期望的样子是一棵完全二叉树。
现在来说一下具体的做法。我们发现,如果我们固定一个点 \(x\),那么 \(\forall y\),都有 \(dis(x,y)=dep_x+dep_y-2\times\mathbb{lca}(x,y)\)。这就启示我们,如果在访问到 \(x\) 时,我们已经对于 \(x\) 的所有祖先 \(u\),都在一棵权值线段树保存了所有 \(dep_y-2\times dep_u\) 的值,那么求解 \(x\) 的答案就只需要在线段树上跑一个二分。(当然,如果你不嫌常数大,平衡树也是可以的!)
接下来的问题就是,我们该如何维护这棵不断变化的线段树呢?
假设现在访问到了节点 \(x\)。
- 对于 \(x\) 子树内的所有节点 \(p\),都把 \(dep_p-2\times dep_x\) 插入到线段树中。
- 在线段树上跑二分,求解 \(x\) 的答案。
- 对于 \(x\) 的一个儿子 \(y\)。
- 对于 \(y\) 子树内的所有节点 \(q\),都把 \(dep_q-2\times dep_x\) 从线段树中删除。
- 递归求解 \(y\)。
- 对于 \(y\) 子树内的所有节点 \(q\),都把 \(dep_q-2\times dep_x\) 重新插入到线段树中。
- 对于 \(x\) 子树内的所有节点 \(p\),都把 \(dep_p-2\times dep_x\) 从线段树中删除。
由于树高期望为 \(\log n\),所以一个节点最多会被加入线段树 \(O(\log n)\) 次,一次插入的复杂度是 \(O(\log n)\),所以最终的复杂度是 \(O(n\log^2 n)\)。
三、代码:
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define FILEIN(s) freopen(s, "r", stdin)
#define FILEOUT(s) freopen(s, "w", stdout)
#define mem(s, v) memset(s, v, sizeof s)
inline int read(void) {
int x = 0, f = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
return f * x;
}
const int MAXN = 2e5 + 5, LOGMAXN = 18, LEN = 200000;
int n, K;
int head[MAXN], tot, ans[MAXN];
int son[MAXN * LOGMAXN * LOGMAXN][2]; // 我为了省空间,用了动态开点线段树。
int sum[MAXN * LOGMAXN * LOGMAXN], sz = 1;
int dep[MAXN];
struct Edge {
int y, next, w;
Edge() {}
Edge(int _y, int _next, int _w) : y(_y), next(_next), w(_w) {}
}e[MAXN << 1];
inline void connect(int x, int y, int w) {
e[++tot] = Edge(y, head[x], w);
head[x] = tot;
}
inline void pushup(int o) {
sum[o] = sum[son[o][0]] + sum[son[o][1]];
}
inline void insert(int o, int l, int r, int q, int v) {
if (l == r) { sum[o] += v; return; }
int mid = (l + r) >> 1;
if (q <= mid) {
if (!son[o][0]) son[o][0] = ++sz;
insert(son[o][0], l, mid, q, v);
}
else {
if (!son[o][1]) son[o][1] = ++sz;
insert(son[o][1], mid + 1, r, q, v);
}
pushup(o);
}
int query(int o, int l, int r, int k) {
if (l == r) { return l; }
int mid = (l + r) >> 1;
if (sum[son[o][0]] >= k) return query(son[o][0], l, mid, k);
else return query(son[o][1], mid + 1, r, k - sum[son[o][0]]);
}
void prework(int x, int fa) {
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].y;
if (y == fa) continue;
dep[y] = dep[x] + e[i].w;
prework(y, x);
}
}
void modify(int x, int fa, int v, int t) {
insert(1, -LEN, LEN, dep[x] + v, t);
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].y;
if (y == fa) continue;
modify(y, x, v, t);
}
}
void dfs(int x, int fa) {
modify(x, fa, -2 * dep[x], 1);
ans[x] = query(1, -LEN, LEN, K) + dep[x];
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].y;
if (y == fa) continue;
modify(y, x, -2 * dep[x], -1);
dfs(y, x);
modify(y, x, -2 * dep[x], 1);
}
modify(x, fa, -2 * dep[x], -1);
}
int main() {
FILEIN("path.in"); FILEOUT("path.out");
n = read(); K = read();
for (int i = 1; i < n; ++i) {
int x = read(), y = read(), w = read();
connect(x, y, w); connect(y, x, w);
}
prework(1, 0);
dfs(1, 0);
for (int x = 1; x <= n; ++x) printf("%d\n", ans[x]);
return 0;
}