K shortest path on tree 题解

一、题目:


二、思路:

个人觉得这是一道经典的套路题,需要牢记这种处理方式。

其实这道题一个非常重要的性质就是:树的形态随机。在这种情况下,有一个很重要的结论:树高期望为 \(\log n\)。或者我们可以理解成,这棵树期望的样子是一棵完全二叉树。

现在来说一下具体的做法。我们发现,如果我们固定一个点 \(x\),那么 \(\forall y\),都有 \(dis(x,y)=dep_x+dep_y-2\times\mathbb{lca}(x,y)\)。这就启示我们,如果在访问到 \(x\) 时,我们已经对于 \(x\) 的所有祖先 \(u\),都在一棵权值线段树保存了所有 \(dep_y-2\times dep_u\) 的值,那么求解 \(x\) 的答案就只需要在线段树上跑一个二分。(当然,如果你不嫌常数大,平衡树也是可以的!)

接下来的问题就是,我们该如何维护这棵不断变化的线段树呢?

假设现在访问到了节点 \(x\)

  1. 对于 \(x\) 子树内的所有节点 \(p\),都把 \(dep_p-2\times dep_x\) 插入到线段树中。
  2. 在线段树上跑二分,求解 \(x\) 的答案。
  3. 对于 \(x\) 的一个儿子 \(y\)
    1. 对于 \(y\) 子树内的所有节点 \(q\),都把 \(dep_q-2\times dep_x\) 从线段树中删除。
    2. 递归求解 \(y\)
    3. 对于 \(y\) 子树内的所有节点 \(q\),都把 \(dep_q-2\times dep_x\) 重新插入到线段树中。
  4. 对于 \(x\) 子树内的所有节点 \(p\),都把 \(dep_p-2\times dep_x\) 从线段树中删除。

由于树高期望为 \(\log n\),所以一个节点最多会被加入线段树 \(O(\log n)\) 次,一次插入的复杂度是 \(O(\log n)\),所以最终的复杂度是 \(O(n\log^2 n)\)

三、代码:

#include <iostream> 
#include <cstdio>
#include <cstring>

using namespace std;
#define FILEIN(s) freopen(s, "r", stdin)
#define FILEOUT(s) freopen(s, "w", stdout)
#define mem(s, v) memset(s, v, sizeof s)

inline int read(void) {
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return f * x;
}

const int MAXN = 2e5 + 5, LOGMAXN = 18, LEN = 200000;

int n, K;
int head[MAXN], tot, ans[MAXN];

int son[MAXN * LOGMAXN * LOGMAXN][2]; // 我为了省空间,用了动态开点线段树。
int sum[MAXN * LOGMAXN * LOGMAXN], sz = 1;

int dep[MAXN];

struct Edge {
    int y, next, w;
    Edge() {}
    Edge(int _y, int _next, int _w) : y(_y), next(_next), w(_w) {}
}e[MAXN << 1];

inline void connect(int x, int y, int w) {
    e[++tot] = Edge(y, head[x], w);
    head[x] = tot;
}

inline void pushup(int o) {
    sum[o] = sum[son[o][0]] + sum[son[o][1]];
}

inline void insert(int o, int l, int r, int q, int v) {
    if (l == r) { sum[o] += v; return; }
    int mid = (l + r) >> 1;
    if (q <= mid) {
        if (!son[o][0]) son[o][0] = ++sz;
        insert(son[o][0], l, mid, q, v);
    }
    else {
        if (!son[o][1]) son[o][1] = ++sz;
        insert(son[o][1], mid + 1, r, q, v);
    }
    pushup(o);
}

int query(int o, int l, int r, int k) {
    if (l == r) { return l; }
    int mid = (l + r) >> 1;
    if (sum[son[o][0]] >= k) return query(son[o][0], l, mid, k);
    else return query(son[o][1], mid + 1, r, k - sum[son[o][0]]);
}

void prework(int x, int fa) {
    for (int i = head[x]; i; i = e[i].next) {
        int y = e[i].y;
        if (y == fa) continue;
        dep[y] = dep[x] + e[i].w;
        prework(y, x);
    }
}

void modify(int x, int fa, int v, int t) {
    insert(1, -LEN, LEN, dep[x] + v, t);
    for (int i = head[x]; i; i = e[i].next) {
        int y = e[i].y;
        if (y == fa) continue;
        modify(y, x, v, t);
    }
}

void dfs(int x, int fa) {
    modify(x, fa, -2 * dep[x], 1);
    ans[x] = query(1, -LEN, LEN, K) + dep[x];
    for (int i = head[x]; i; i = e[i].next) {
        int y = e[i].y;
        if (y == fa) continue;
        modify(y, x, -2 * dep[x], -1);
        dfs(y, x);
        modify(y, x, -2 * dep[x], 1);
    }
    modify(x, fa, -2 * dep[x], -1);
}

int main() {
    FILEIN("path.in"); FILEOUT("path.out");
    n = read(); K = read();
    for (int i = 1; i < n; ++i) {
        int x = read(), y = read(), w = read();
        connect(x, y, w); connect(y, x, w);
    }
    prework(1, 0);
    dfs(1, 0);
    for (int x = 1; x <= n; ++x) printf("%d\n", ans[x]);
    return 0;
}

posted @ 2021-07-01 19:08  蓝田日暖玉生烟  阅读(76)  评论(3编辑  收藏  举报