LCA + 树状数组 + 树上RMQ
题目链接:http://poj.org/problem?id=2763
思路:首先求出树上dfs序列,并且标记树上每个节点开始遍历以及最后回溯遍历到的时间戳,由于需要修改树上的某两个节点之间的权值,如果parent[v] = u, 那么说明修改之后的v的子树到当前根的距离都会改变,由于遍历到v时有开始时间戳以及结束时间戳,那么处于这个区间所有节点都会影响到,于是我们可以通过数组数组来更新某个区间的值,只需令从区间起始点之后的每个值都增加一个改变量了,令区间中止点之后的每个值都减小一个改变量,这样我们就可以得到处于该区间的值的改变量。至于如何求树上两个节点的LCA,这里可以使用RMQ,O(n * log(n))的预处理,O(1)的查询复杂度。\
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <cmath> using namespace std; const int MAX_N = (200000 + 10000); struct Edge { int u, v, w, next; Edge () {} Edge (int _u, int _v, int _w, int _next) : u(_u), v(_v), w(_w), next(_next) {} } EE[MAX_N << 2], edge[MAX_N << 2]; int NE, head[MAX_N]; void Init() { NE = 0; memset(head, -1, sizeof(head)); } void Insert(int u, int v, int w) { edge[NE].u = u; edge[NE].v = v; edge[NE].w = w; edge[NE].next = head[u]; head[u] = NE++; } int index, dfs_seq[MAX_N << 2]; //dfs序列 int First[MAX_N], End[MAX_N]; //First数组表示第一次遍历到的时间戳,End数组表示最后回溯时遍历到的时间戳 int parent[MAX_N], dep[MAX_N << 2]; //深度 int N, Q, st, cost[MAX_N]; void dfs(int u, int fa, int d, int c) { parent[u] = fa; First[u] = index; dfs_seq[index] = u; dep[index++] = d; cost[u] = c; for (int i = head[u]; ~i; i = edge[i].next) { int v = edge[i].v, w = edge[i].w; if (v == fa) continue; dfs(v, u, d + 1, cost[u] + w); dfs_seq[index] = u; dep[index++] = d; } End[u] = index; } int dp[MAX_N][40]; //dp[i][j]表示从时间戳i开始,长度为(1 << j)的区间中深度最小的点的时间戳 void Init_RMQ() { for (int i = 0; i < index; ++i) dp[i][0] = i; for (int j = 1; (1 << j) < index; ++j) { for (int i = 0; i + (1 << j) - 1 < index; ++i) { dp[i][j] = (dep[ dp[i][j - 1] ] < dep[ dp[i + (1 << (j - 1))][j - 1] ] ? dp[i][j - 1] : dp[i + (1 << (j - 1))][j - 1]); } } } int RMQ_Query(int l, int r) { int k = (int)(log(r * 1.0 - l + 1) / log(2.0)); return (dep[ dp[l][k] ] < dep[ dp[r - (1 << k) + 1][k] ] ? dp[l][k] : dp[r - (1 << k) + 1][k]); } int LCA(int x, int y) { if (x > y) swap(x, y); return dfs_seq[RMQ_Query(x, y)]; } int C[MAX_N << 2]; int lowbit(int x) { return x & (-x); } void update(int i, int val) { while (i < index) { C[i] += val; i += lowbit(i); } } int getSum(int i) { int sum = 0; while (i > 0) { sum += C[i]; i -= lowbit(i); } return sum; } int main() { while (~scanf("%d %d %d", &N, &Q, &st)) { Init(); for (int i = 1; i < N; ++i) { int u, v, w; scanf("%d %d %d", &u, &v, &w); Insert(u, v, w); Insert(v, u, w); EE[i] = Edge(u, v, w, -1); } index = 0; dfs(st, st, 0, 0); Init_RMQ(); memset(C, 0, sizeof(C)); for (int i = 1; i <= Q; ++i) { int opt; scanf("%d", &opt); if (opt == 0) { int ed, fa, x, y, z; scanf("%d", &ed); x = First[st], y = First[ed], z = First[fa = LCA(x, y)]; printf("%d\n", getSum(x + 1) + getSum(y + 1) - 2 * getSum(z + 1) + cost[st] + cost[ed] - 2 * cost[fa]); st = ed; } else { int id, w; scanf("%d %d", &id, &w); int u = EE[id].u, v = EE[id].v, tmp = w - EE[id].w; EE[id].w = w; if (parent[u] == v) swap(u, v); update(First[v] + 1, tmp); update(End[v] + 1, -tmp); } } } return 0; }