// POJ-1986 Distance Queries
// Distance Queries
// Template problem: distance between two nodes of a weighted tree (树上距离板题).
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
const int maxn = 40000 + 10;     // node capacity; nodes are indexed 1..n
typedef long long ll;
typedef std::pair<int, ll> pii;  // adjacency entry: (neighbour, edge weight)
std::vector<pii> gra[maxn];      // adjacency list of the undirected tree
int dep[maxn];                   // depth of each node; the root (1) has depth 1
int fa[maxn][20];                // fa[v][k] = 2^k-th ancestor of v (binary lifting)
ll sum[maxn];                    // weighted distance from the root to each node
void dfs(int now, int pre, int d)
{
fa[now][0] = pre;
dep[now] = d;
sum[now] += sum[pre];
for(int i=0; i<gra[now].size(); i++)
{
pii &a = gra[now][i];
if(a.first == pre) continue;
sum[a.first] += a.second;
dfs(a.first, now, d + 1);
}
}
// Preprocesses the tree for LCA queries on nodes 1..n.
// Vertex 1 is the root and is made its own parent, so ancestor jumps that go
// past the root stay at the root.
void init(int n)
{
    dfs(1, 1, 1);
    // Level k of the table is built from level k-1: the 2^k-th ancestor is
    // the 2^(k-1)-th ancestor of the 2^(k-1)-th ancestor.
    for(int k = 1; k < 20; ++k)
    {
        for(int v = 1; v <= n; ++v)
        {
            int half = fa[v][k - 1];
            fa[v][k] = fa[half][k - 1];
        }
    }
}
// Returns the lowest common ancestor of u and v using the binary-lifting
// table fa[][] and depths dep[] filled in by init().
int LCA(int u, int v)
{
    // Make u the deeper of the two endpoints.
    if(dep[v] > dep[u])
    {
        int t = u;
        u = v;
        v = t;
    }
    // Lift u up to v's depth, taking the largest power-of-two steps first.
    for(int k = 19, gap = dep[u] - dep[v]; k >= 0; --k)
    {
        const int step = 1 << k;
        if(gap < step) continue;
        gap -= step;
        u = fa[u][k];
    }
    if(u != v)
    {
        // Raise both nodes while their ancestors still differ; they end up
        // as the two children of the LCA.
        for(int k = 19; k >= 0; --k)
        {
            if(fa[u][k] == fa[v][k]) continue;
            u = fa[u][k];
            v = fa[v][k];
        }
        u = fa[u][0];
    }
    return u;
}
// Weighted length of the tree path between u and v. sum[x] is the weighted
// distance from the root to x, so the path length is
// sum[u] + sum[v] - 2 * sum[lca(u, v)].
ll dis(int u, int v)
{
    const int meet = LCA(u, v);
    const ll viaRoot = sum[u] + sum[v];
    return viaRoot - 2 * sum[meet];
}
// Reads the input and answers the distance queries:
//   "n m", then m edge lines "a b weight token", then q, then q lines "a b".
// Prints one path length per query.
int main()
{
    int n, m;
    if(scanf("%d%d", &n, &m) != 2) return 0;
    for(int i=0; i<m; i++)
    {
        int a, b;
        ll c;
        // The trailing token of each edge line is read but never used.
        // The %2s width keeps the write inside the 3-byte buffer.
        char op[3];
        if(scanf("%d%d%lld%2s", &a, &b, &c, op) != 4) break;
        gra[a].push_back(make_pair(b, c));
        gra[b].push_back(make_pair(a, c));
    }
    init(n);
    int q;
    if(scanf("%d", &q) != 1) return 0;
    while(q--)
    {
        int a, b;
        if(scanf("%d%d", &a, &b) != 2) break;
        printf("%lld\n", dis(a, b));
    }
    return 0;
}