Loading

SPOJ-QTREE2 Query on a tree II

Query on a tree II

倍增 LCA 或 树链剖分

找路径上第 k 个结点

先求出 LCA，再根据深度判断第 k 个结点位于 a→LCA 一段还是 LCA→b 一段，最后用倍增从对应端点向上跳到该结点。

#include <iostream>
#include <vector>
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;
// Type alias instead of a macro: it is scoped, visible to the debugger and
// cannot be accidentally expanded inside unrelated tokens.
using pii = pair<int, int>;
// Capacity for vertex-indexed tables; presumably the problem's N <= 10000 bound
// plus slack (TODO confirm against the statement). Written as an exact integer
// rather than the double literal 1e4 + 10.
constexpr int maxn = 10000 + 10;
// Binary-lifting table: fa[v][i] = {2^i-th ancestor of v, summed edge weight of
// that jump}. init() fills levels 0..20; the root is its own parent with weight 0.
pii fa[maxn][25];
// Depth of each vertex; the root gets depth 1 (see init()).
int dep[maxn];
// Adjacency lists: gra[v] holds {neighbour, edge weight} pairs.
vector<pii>gra[maxn];

// Recursive depth-first walk from `now`. It records, for every vertex below it:
//   - dep[v]:    the vertex depth
//   - fa[v][0]:  the immediate parent together with the weight of the edge to it
// Parameters:
//   now - vertex currently being visited
//   pre - vertex we arrived from; skipped so the walk never goes back up
//   d   - depth to assign to `now`
// Recursion depth equals the height of the tree.
void dfs(int now, int pre, int d)
{
    dep[now] = d;
    for (const auto& edge : gra[now])
    {
        const int child = edge.first;
        const int weight = edge.second;
        if (child != pre)
        {
            fa[child][0] = {now, weight};
            dfs(child, now, d + 1);
        }
    }
}

void init(int n, int rt)
{
    fa[rt][0] = {rt, 0};
    dfs(rt, rt, 1);
    for(int i=1; i<=20; i++)
    {
        for(int j=1; j<=n; j++)
        {
            fa[j][i] = fa[fa[j][i-1].first][i-1];
            fa[j][i].second += fa[j][i-1].second;
        }
    }
    for(int i=0; i<=n; i++) gra[i].clear();
}

// Returns {lowest common ancestor of a and b, total edge weight on the a-b path}.
// Relies on fa/dep having been built by init() for the current tree.
pii LCA(int a, int b)
{
    if (dep[b] > dep[a])
        swap(a, b);  // make `a` the deeper endpoint
    int dist = 0;

    // Stage 1: lift `a` up to b's depth, accumulating the weight of each jump.
    int gap = dep[a] - dep[b];
    for (int lvl = 20; lvl >= 0 && gap > 0; --lvl)
    {
        const int step = 1 << lvl;
        if (gap < step)
            continue;
        dist += fa[a][lvl].second;
        a = fa[a][lvl].first;
        gap -= step;
    }
    if (a == b)
        return {a, dist};

    // Stage 2: lift both sides while their 2^lvl-th ancestors still differ.
    // They finish as the two children of the LCA on the path.
    for (int lvl = 20; lvl >= 0; --lvl)
    {
        const int upA = fa[a][lvl].first;
        const int upB = fa[b][lvl].first;
        if (upA == upB)
            continue;
        dist += fa[a][lvl].second + fa[b][lvl].second;
        a = upA;
        b = upB;
    }

    // Stage 3: one final edge from each side reaches the LCA.
    dist += fa[a][0].second + fa[b][0].second;
    return {fa[a][0].first, dist};
}

// Returns the k-th vertex (1-based; `a` itself is the 1st) on the path a -> b.
// The path is split at the LCA:
//   - if k falls on the a -> lca half, climb k-1 levels from a;
//   - otherwise, count backwards from b and climb (path_length - k) levels from b.
int kth(int a, int b, int k)
{
    const int lca = LCA(a, b).first;
    const int upSide = dep[a] - dep[lca] + 1;  // vertices on a -> lca, inclusive

    int from;
    int climb;
    if (k > upSide)
    {
        from = b;
        climb = (dep[a] - dep[lca]) + (dep[b] - dep[lca]) + 1 - k;
    }
    else
    {
        from = a;
        climb = k - 1;
    }

    // Decompose `climb` greedily into powers of two and jump.
    for (int lvl = 20; lvl >= 0; --lvl)
    {
        if (climb >= (1 << lvl))
        {
            climb -= 1 << lvl;
            from = fa[from][lvl].first;
        }
    }
    return from;
}

// Buffer for the query keyword ("DIST", "KTH", "DONE").
char op[20];

// Input: T test cases. Each case gives n, then n-1 weighted edges "x y v",
// then a stream of queries ended by a keyword whose 2nd letter is neither
// 'I' nor 'T' (e.g. "DONE"):
//   DIST a b   -> prints the total edge weight on the a-b path
//   KTH a b k  -> prints the k-th vertex on the path from a to b
// Every scanf result is checked: truncated input ends the program instead of
// looping forever on a stale keyword or using uninitialized values.
int main()
{
    int t = 0;
    if (scanf("%d", &t) != 1)
        return 0;
    while (t-- > 0)
    {
        int n = 0;
        if (scanf("%d", &n) != 1)
            break;
        for (int i = 1; i < n; i++)
        {
            int x, y, v;
            if (scanf("%d%d%d", &x, &y, &v) != 3)
                return 0;
            gra[x].push_back({y, v});
            gra[y].push_back({x, v});
        }
        init(n, 1);
        // "%19s" keeps the read inside op[20]. The loop stops on EOF as well
        // as on the terminating keyword.
        while (scanf("%19s", op) == 1)
        {
            if (op[1] == 'I')  // DIST
            {
                int a, b;
                if (scanf("%d%d", &a, &b) != 2)
                    return 0;
                printf("%d\n", LCA(a, b).second);
            }
            else if (op[1] == 'T')  // KTH
            {
                int a, b, k;
                if (scanf("%d%d%d", &a, &b, &k) != 3)
                    return 0;
                printf("%d\n", kth(a, b, k));
            }
            else
                break;  // DONE
        }
    }
    return 0;
}
posted @ 2022-07-07 21:59  dgsvygd  阅读(29)  评论(0)  编辑  收藏  举报