模板-树链剖分

因为要做bzoj1036这个题,所以学习一下树链剖分。

/**************************************************************
    Problem: 1036
    User: MrMorning
    Language: C++
    Result: Accepted
    Time:2868 ms
    Memory:7332 kb
****************************************************************/
 
#include <bits/stdc++.h>
// Import only the standard names this file uses unqualified. A blanket
// `using namespace std;` makes the globals `size` and `next` declared below
// ambiguous with std::size (C++17) and std::next (C++11) wherever they are used.
using std::cin;
using std::max;
using std::printf;
using std::scanf;
using std::swap;
using std::vector;

const int maxn = 30006; // capacity of every per-node array; nodes are 1-based
//=====================
vector<int> G[maxn];   // undirected adjacency lists, filled in main()
vector<int> son[maxn]; // children of each node in the rooted tree (build_tree)
int weigh[maxn];       // node weights as read from the input
int size[maxn];        // subtree sizes computed by build_tree
int vis[maxn];         // visited flags used by build_tree
int depth[maxn];       // depth of each node; the root keeps depth 0
int fa[maxn];          // parent of each node in the rooted tree
int next[maxn];        // child with the largest subtree, or -1 for a leaf
int top[maxn];         // first (shallowest) node of the chain a node belongs to
int id[maxn];          // segment-tree position assigned to each node by dfs
void build_tree(int root);
void build(int k, int l, int r);
void dfs(int root, int);
// These three declarations mirror the definitions further down in the file.
void change(int k, int x, int y);
int query_sum(int num, int ql, int qr);
int query_max(int num, int ql, int qr);
int path_sum(int u, int v);
int path_max(int u, int v);
void solve();
int p, x;        // not referenced by the code shown in this file
int n;           // number of nodes
int nz = 0;      // last segment-tree position handed out by dfs
int value[maxn]; // not referenced by the code shown in this file
// One segment-tree node: the closed interval [l, r] it covers, plus the sum
// and the maximum of the values stored in that interval.
struct dat {
  int l, r, sum, mx;
} seg[4 * maxn]; // 4x the leaf capacity is enough for a recursively built tree
//=================
int main() {
  // freopen("tree.in", "r", stdin);
  //  freopen("tree.out", "w", stdout);
 
  scanf("%d", &n);
  for (int i = 1; i < n; i++) {
    int u, v;
    scanf("%d %d", &u, &v);
    G[u].push_back(v);
    G[v].push_back(u);
  }
  for (int i = 1; i <= n; i++)
    cin >> weigh[i];
  solve();
  return 0;
}
// Records the closed interval [l, r] on segment-tree node k and, unless the
// node is a leaf, splits the interval at its midpoint between the children
// 2k and 2k+1. Only the interval bounds are written here; sum and mx are
// left as they were.
void build(int k, int l, int r) {
  seg[k].r = r;
  seg[k].l = l;
  if (l != r) {
    const int half = (l + r) >> 1;
    const int leftChild = 2 * k;
    build(leftChild, l, half);
    build(leftChild + 1, half + 1, r);
  }
}
// Point update: stores y at position x in the subtree rooted at node k, then
// recomputes sum and mx of every node on the way back up.
void change(int k, int x, int y) {
  if (seg[k].l == seg[k].r) {
    // Leaf: it holds exactly one value, so sum and maximum coincide.
    seg[k].mx = y;
    seg[k].sum = y;
    return;
  }
  const int half = (seg[k].l + seg[k].r) >> 1;
  const int leftChild = 2 * k;
  const int rightChild = leftChild + 1;
  // Descend into whichever half contains position x.
  change(x <= half ? leftChild : rightChild, x, y);
  seg[k].mx = std::max(seg[leftChild].mx, seg[rightChild].mx);
  seg[k].sum = seg[leftChild].sum + seg[rightChild].sum;
}
// Builds the decomposition and the segment tree from the already-read tree,
// then reads the operation count followed by that many "<command> <a> <b>"
// operations from standard input.
//   - second letter 'M': print the maximum weight on the path a..b
//   - second letter 'S': print the sum of weights on the path a..b
//   - anything else    : set the weight of node a to b
// Processing stops at the first operation that cannot be parsed; operations
// naming a node outside [1, n] are skipped without producing output.
void solve() {
  build_tree(1);
  dfs(1, 1);
  build(1, 1, n);
  // Load every node's weight into its segment-tree position.
  for (int i = 1; i <= n; i++) {
    change(1, id[i], weigh[i]);
  }
  int t = 0;
  if (!(cin >> t))
    return;
  while (t-- > 0) {
    int u, v;
    char command[10];
    // %9s bounds the token to the buffer (9 characters plus the terminator).
    if (scanf("%9s %d %d", command, &u, &v) != 3)
      break;
    // u is a node number in every operation and is used as an array index.
    if (u < 1 || u > n)
      continue;
    const bool isMax = command[1] == 'M';
    const bool isSum = command[1] == 'S';
    if (isMax || isSum) {
      // For the two path queries v is a node number as well.
      if (v < 1 || v > n)
        continue;
      printf("%d\n", isMax ? path_max(u, v) : path_sum(u, v));
    } else {
      change(1, id[u], v);
    }
  }
}
// First traversal: roots the tree below `root`, filling in fa, depth, size
// and the child list son[] for every node reached, and records in next[] the
// child with the largest subtree (the first such child on ties, -1 when the
// node has no unvisited neighbour).
void build_tree(int root) { // dfs1
  vis[root] = 1;
  size[root] = 1;
  int heavyChild = -1;
  int heavySize = -1;
  const int degree = static_cast<int>(G[root].size());
  for (int i = 0; i < degree; ++i) {
    const int child = G[root][i];
    if (vis[child])
      continue; // already reached earlier in the traversal
    fa[child] = root;
    depth[child] = depth[root] + 1;
    build_tree(child);
    son[root].push_back(child);
    size[root] += size[child];
    if (size[child] > heavySize) {
      heavySize = size[child];
      heavyChild = child;
    }
  }
  next[root] = heavyChild;
}
// Second traversal: gives `root` the next free segment-tree position and
// marks `chain` as the head of its chain. The child stored in next[root] is
// visited first and stays on the same chain; every other child starts a new
// chain headed by itself.
void dfs(int root, int chain) { // dfs2
  top[root] = chain;
  id[root] = ++nz;
  if (son[root].empty())
    return;
  const int heavyChild = next[root];
  dfs(heavyChild, chain);
  const int childCount = static_cast<int>(son[root].size());
  for (int i = 0; i < childCount; ++i) {
    const int child = son[root][i];
    if (child == heavyChild)
      continue; // already handled as the continuation of this chain
    dfs(child, child);
  }
}
// Returns the maximum stored over positions [ql, qr] inside the subtree of
// segment-tree node `num`. The queried range is narrowed on the way down so
// that it always lies within the interval of the node being visited.
int query_max(int num, int ql, int qr) {
  if (seg[num].l == ql && seg[num].r == qr)
    return seg[num].mx;
  const int half = (seg[num].l + seg[num].r) >> 1;
  const int leftChild = 2 * num;
  const int rightChild = leftChild + 1;
  if (qr <= half)
    return query_max(leftChild, ql, qr);
  if (ql > half)
    return query_max(rightChild, ql, qr);
  // The range straddles the midpoint: answer each half separately.
  const int leftBest = query_max(leftChild, ql, half);
  const int rightBest = query_max(rightChild, half + 1, qr);
  return std::max(leftBest, rightBest);
}
// Returns the sum stored over positions [ql, qr] inside the subtree of
// segment-tree node `num`. The queried range is narrowed on the way down so
// that it always lies within the interval of the node being visited.
int query_sum(int num, int ql, int qr) {
  if (seg[num].l == ql && seg[num].r == qr)
    return seg[num].sum;
  const int half = (seg[num].l + seg[num].r) >> 1;
  const int leftChild = 2 * num;
  const int rightChild = leftChild + 1;
  if (qr <= half)
    return query_sum(leftChild, ql, qr);
  if (ql > half)
    return query_sum(rightChild, ql, qr);
  // The range straddles the midpoint: add the two partial sums.
  const int leftPart = query_sum(leftChild, ql, half);
  const int rightPart = query_sum(rightChild, half + 1, qr);
  return leftPart + rightPart;
}
// Returns the largest weight on the tree path between nodes x and y.
// While the two nodes sit on different chains, the one whose chain head is
// deeper is lifted to the parent of that head, and the chain segment it
// leaves behind is folded into the running maximum. Once both nodes share a
// chain, the remaining segment is a single contiguous range of positions.
int path_max(int x, int y) {
  // INT_MIN is a true lower bound for any int weight; the previous sentinel
  // (-0x3f3f3f) would win against weights smaller than about -4.1 million.
  int mx = INT_MIN;
  while (top[x] != top[y]) {
    if (depth[top[x]] < depth[top[y]])
      swap(x, y);
    mx = max(mx, query_max(1, id[top[x]], id[x]));
    x = fa[top[x]];
  }
  // Same chain: order the endpoints by position before the final query.
  if (id[x] > id[y])
    swap(x, y);
  mx = max(mx, query_max(1, id[x], id[y]));
  return mx;
}
// Returns the sum of the weights on the tree path between nodes u and v.
// Each round handles the node whose chain head is deeper: its chain segment
// is added to the total and the node jumps to the parent of that head. When
// both nodes end up on one chain, the range between their positions is added.
int path_sum(int u, int v) {
  int total = 0;
  for (; top[u] != top[v]; u = fa[top[u]]) {
    if (depth[top[u]] < depth[top[v]])
      std::swap(u, v);
    total += query_sum(1, id[top[u]], id[u]);
  }
  const int low = std::min(id[u], id[v]);
  const int high = std::max(id[u], id[v]);
  return total + query_sum(1, low, high);
}

posted on 2016-12-02 18:08  蒟蒻konjac  阅读(131)  评论(0)  编辑  收藏  举报