Loading

POJ-1986 Distance Queries

Distance Queries

树上两点距离模板题：以 1 号点为根做 DFS，预处理每个点的深度和到根的距离，再用倍增求 LCA，则 dis(u, v) = sum[u] + sum[v] - 2 * sum[lca(u, v)]。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
// Size of every per-vertex array below; vertex ids are used directly as
// indices. Written as an integer expression: `4e4 + 10` is a floating-point
// expression, which pre-C++11 rules do not accept as an integral constant
// expression for an array bound.
const int maxn = 40000 + 10;
typedef long long ll;
// (neighbor vertex, edge length). A typedef instead of a macro so the alias
// follows normal scoping and type rules.
typedef std::pair<int, ll> pii;
// gra[v]: adjacency list of v; each undirected edge is stored at both ends.
std::vector<pii> gra[maxn];
// dep[v]: depth assigned by dfs (the start vertex gets the depth passed in).
// fa[v][k]: ancestor of v reached by 2^k parent steps; 2^20 exceeds maxn.
int dep[maxn], fa[maxn][20];
// sum[v]: total edge length on the path from the dfs start vertex to v.
ll sum[maxn];

// Depth-first walk over gra starting at `now`.
//   now - vertex being visited
//   pre - vertex we arrived from (stored as fa[now][0]); edges leading back
//         to it are not followed
//   d   - depth recorded for `now`
// sum[now] has sum[pre] added to it on entry, and each neighbor has the
// connecting edge length added to its sum just before it is visited. With
// sum all zero beforehand, sum[x] ends up as the path length from the start
// vertex to x.
void dfs(int now, int pre, int d) {
    dep[now] = d;
    fa[now][0] = pre;
    sum[now] += sum[pre];

    const int degree = static_cast<int>(gra[now].size());
    for (int k = 0; k < degree; ++k) {
        const int child = gra[now][k].first;
        if (child != pre) {
            sum[child] += gra[now][k].second;
            dfs(child, now, d + 1);
        }
    }
}

// Preprocesses the tree for LCA queries.
// Runs dfs from vertex 1 (its own parent, depth 1), then fills the ancestor
// table level by level for vertices 1..n: the 2^lvl-th ancestor of v is the
// 2^(lvl-1)-th ancestor of v's 2^(lvl-1)-th ancestor.
void init(int n) {
    dfs(1, 1, 1);
    for (int lvl = 1; lvl < 20; ++lvl) {
        for (int v = 1; v <= n; ++v) {
            const int half = fa[v][lvl - 1];
            fa[v][lvl] = fa[half][lvl - 1];
        }
    }
}

// Returns the lowest common ancestor of u and v using the dep and fa tables.
int LCA(int u, int v) {
    // Make u the deeper (or equally deep) vertex.
    if (dep[v] > dep[u]) std::swap(u, v);

    // Lift u until both vertices sit at the same depth, largest jump first.
    int gap = dep[u] - dep[v];
    for (int k = 19; k >= 0; --k) {
        const int step = 1 << k;
        if (step <= gap) {
            u = fa[u][k];
            gap -= step;
        }
    }

    if (u != v) {
        // Lift both while their ancestors still differ; afterwards they are
        // distinct vertices sharing the same parent.
        for (int k = 19; k >= 0; --k) {
            const int upU = fa[u][k];
            const int upV = fa[v][k];
            if (upU != upV) {
                u = upU;
                v = upV;
            }
        }
        u = fa[u][0];
    }
    return u;
}

// Distance between u and v: the two prefix sums minus twice the prefix sum
// of their lowest common ancestor, which removes the shared part of the
// paths.
ll dis(int u, int v) {
    const int meet = LCA(u, v);
    const ll shared = sum[meet] * 2;
    return sum[u] - shared + sum[v];
}

int main()
{
    int n, m;
    scanf("%d%d", &n, &m);
    for(int i=0; i<m; i++)
    {
        int a, b;
        ll c;
        char op[3];
        scanf("%d%d%lld%s", &a, &b, &c, op);
        gra[a].push_back(make_pair(b, c));
        gra[b].push_back(make_pair(a, c));
    }
    init(n);
    int q;
    scanf("%d", &q);
    while(q--)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        printf("%lld\n", dis(a, b));
    }
    // cout << endl;
    return 0;
}
posted @ 2022-07-29 16:31  dgsvygd  阅读(23)  评论(0)  编辑  收藏  举报