zoj3195(lca / RMQ在线)

题目链接: http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3195

 

题意: 给出一棵 n 个节点的带边权的树, 有 q 组形如 x, y, z 的询问, 输出 x, y, z之间的最短路径.

 

思路: 在纸上画下不难发现 x, y, z之间的最短路径就是 x, y, z 两两之间的最短路径和的一半.

我们可以通过 lca 模板求出 x, y, z 两两之间的最短路径, 然后再算下 x, y, z三点之间的最短路径即可.

这题应该是用 RMQ 在线比较好写一点, 用 tarjan 的话记录路径有点麻烦.

 

代码:

 1 #include <iostream>
 2 #include <stdio.h>
 3 #include <string.h>
 4 #include <math.h>
 5 using namespace std;
 6 
 7 const int MAXN = 5e4 + 10;
 8 struct node{
 9     int v, w, next;
10 }edge[MAXN << 1];
11 
12 int dp[MAXN << 1][30];
13 int ver[MAXN << 1], deep[MAXN << 1], first[MAXN];
14 int dis[MAXN], head[MAXN], vis[MAXN], indx, ip;
15 
16 inline void init(void){
17     memset(vis, 0, sizeof(vis));
18     memset(head, -1, sizeof(head));
19     indx = 0;
20     ip = 0;
21 }
22 
23 void addedge(int u, int v, int w){
24     edge[ip].v = v;
25     edge[ip].w = w;
26     edge[ip].next = head[u];
27     head[u] = ip++;
28 }
29 
30 void dfs(int u, int h){
31     vis[u] = 1;
32     ver[++indx] = u;
33     deep[indx] = h;
34     first[u] = indx;
35     for(int i = head[u]; i != -1; i = edge[i].next){
36         int v = edge[i].v;
37         if(!vis[v]){
38             dis[v] = dis[u] + edge[i].w;
39             dfs(v, h + 1);
40             ver[++indx] = u;
41             deep[indx] = h;
42         }
43     }
44 }
45 
46 void ST(int n){
47     for(int i = 1; i <= n; i++){
48         dp[i][0] = i;
49     }
50     for(int j = 1; (1 << j) <= n; j++){
51         for(int i = 1; i + (1 << j) - 1 <= n; i++){
52             int x = dp[i][j - 1], y = dp[i + (1 << (j - 1))][j - 1];
53             dp[i][j] = deep[x] < deep[y] ? x : y;
54         }
55     }
56 }
57 
58 int RMQ(int l, int r){
59     int len = log2(r - l + 1);
60     int x = dp[l][len], y = dp[r - (1 << len) + 1][len];
61     return deep[x] < deep[y] ? x : y;
62 }
63 
64 int LCA(int x, int y){
65     int l = first[x], r = first[y];
66     if(l > r) swap(l, r);
67     int pos = RMQ(l, r);
68     return ver[pos];
69 }
70 
71 int main(void){
72     bool flag = false;
73     int n, q, x, y, z;
74     while(~scanf("%d", &n)){
75         if(flag) puts("");
76         flag = true;
77         init();
78         for(int i = 1; i < n; i++){
79             scanf("%d%d%d", &x, &y, &z);
80             addedge(x, y, z);
81             addedge(y, x, z);
82         }
83         dis[1] = 0;
84         dfs(1, 1);
85         ST(2 * n - 1);
86         scanf("%d", &q);
87         while(q--){
88             scanf("%d%d%d", &x, &y, &z);
89             int lca1 = LCA(x, y);
90             int lca2 = LCA(x, z);
91             int lca3 = LCA(y, z);
92             int sol1 = dis[x] + dis[y] - 2 * dis[lca1];
93             int sol2 = dis[x] + dis[z] - 2 * dis[lca2];
94             int sol3 = dis[y] + dis[z] - 2 * dis[lca3];
95             printf("%d\n", (sol1 + sol2 + sol3) >> 1);
96         }
97     }
98     return 0;
99 }
View Code

 

posted @ 2017-07-20 11:02  geloutingyu  阅读(280)  评论(0编辑  收藏  举报