hdu2586 How far away ?(lca模版题)

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=2586

题意:给出一棵树还有两个点然后求这两个点的最短距离。

 

题解:val[a]+val[b]-2*val[root]就是这两个点到根节点的距离再减去它们最近的公共父节点到根节点的距离的两倍

然后就是利用lca来求最近公共父节点。由于是模版题代码上写了一点注释方便理解。

 

#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>
using namespace std;
const int M = 4e4 + 10;
vector<pair<int , int>>vc[M];
int p[M][20] , deep[M] , val[M];//p[i][j]表示距离i点2的j次的父节点是什么。deep表示深度
void dfs(int pos , int pre , int dep) {
    int len = vc[pos].size();
    deep[pos] = dep;
    p[pos][0] = pre;
    for(int i = 0 ; i < len ; i++) {
        int u = vc[pos][i].first;
        if(u != pre) {
            val[u] += (val[pos] + vc[pos][i].second);
            dfs(u , pos , dep + 1);
        }
    }
}//dfs一遍记录一下每一个点的深度顺便记录一下当前点到父节点的距离
void init(int n) {
    for(int i = 0 ; i < 16 ; i++) {
        for(int j = 1 ; j <= n ; j++) {
            if(p[j][i] == -1) {
                p[j][i + 1] = -1;//如果是根节点之后距离不论为多少都为-1
            }
            else {
                p[j][i + 1] = p[p[j][i]][i];//跟新p
            }
        }
    }
}//初始化p
int lca(int a , int b) {
    if(deep[a] < deep[b]) {
        swap(a , b);
    }
    int d = deep[a] - deep[b];
    for(int i = 0 ; i < 16 ; i++) {
        if(d & (1 << i)) {
            a = p[a][i];
        }//这里很巧妙好好理解一下
    }//使得a,b两点在同一深度
    if(a == b) {
        return a;
    }
    for(int i = 16; i >= 0 ; i--) {
        if(p[a][i] != p[b][i] && p[a][i] != -1) {
            a = p[a][i];
            b = p[b][i];
        }//如果不是共同根那么继续向上遍历。
    }
    return p[a][0];
}
int main() {
    int t , n , m , u , v , w;
    scanf("%d" , &t);
    while(t--) {
        scanf("%d%d" , &n , &m);
        for(int i = 1 ; i <= n ; i++) {
            vc[i].clear();
            val[i] = 0;
        }
        memset(p , -1 , sizeof(p));
        for(int i = 1 ; i < n ; i++) {
            scanf("%d%d%d" , &u , &v , &w);
            vc[u].push_back(make_pair(v , w));
            vc[v].push_back(make_pair(u , w));
        }
        dfs(1 , -1 , 1);
        init(n);
        for(int i = 0 ; i < m ; i++) {
            scanf("%d%d" , &u , &v);
            int pos = lca(u , v);
            printf("%d\n" , val[u] + val[v] - 2 * val[pos]);
        }
    }
    return 0;
}

posted @ 2017-04-20 10:25  Gealo  阅读(143)  评论(0编辑  收藏  举报