bzoj4381 [POI2015]Odwiedziny

bzoj4381 [POI2015]Odwiedziny

给定一棵带点权的树,每次询问在 \(u\)\(v\) 的路径上,每次走 \(k\) 步,如果最后不足 \(k\) 步就走到了 \(v\) ,则会一步走到 \(v\) ,求每次行走的经过的点的点权和

\(n\leq5\times10^4,\ a_i\leq10^4\)

根号分治


考虑根号分治,如果 \(k>\sqrt n\) ,每次暴力枚举走到的节点,反之,预处理所有点每次走 \(i(i\leq\sqrt n)\) 步直到超过根的深度所经过的点权和(与 \(i\) 到根的路径不同),询问时计算贡献。

对于 \(k>\sqrt n\) 的询问,需特判 \(u=v\)\(v=lca\) 的情况,并且需要快速查询一个点的 \(k\) 级祖先,用倍增即可,时间复杂度 \(O(\sqrt n\log n)\)

对于 \(k\leq\sqrt n\) 的询问,同样要特判 \(u=v\)\(v=lca\) 的情况,还需注意 \(lca\to v\) 的路径的贡献。预处理时间复杂度 \(O(n\sqrt n\log n)\) ,单次查询 \(O(\log n)\)

综上,时间复杂度 \(O(n\sqrt n\log n)\) ,貌似把倍增换成长链剖分可以做到 \(O(n\sqrt n)\)

代码(开 C++11)

#include <bits/stdc++.h>
using namespace std;

const int maxn = 5e4 + 10;
int n, bsz, a[maxn], dep[maxn], fa[16][maxn], sum[225][maxn];

vector <int> e[maxn];

int findlca(int u, int v) {
  if (dep[u] < dep[v]) swap(u, v);
  for (int i = 15; ~i; i--) {
    if (dep[u] - (1 << i) >= dep[v]) {
      u = fa[i][u];
    }
  }
  if (u == v) return u;
  for (int i = 15; ~i; i--) {
    if (fa[i][u] != fa[i][v]) {
      u = fa[i][u], v = fa[i][v];
    }
  }
  return fa[0][u];
}

int findanc(int u, int k) {
  for (int i = 0; i < 16; i++) {
    if (k >> i & 1) u = fa[i][u];
  }
  return u;
}

void dfs1(int u, int f) {
  fa[0][u] = f;
  dep[u] = dep[f] + 1;
  for (int i = 1; i < 16; i++) {
    fa[i][u] = fa[i - 1][fa[i - 1][u]];
  }
  for (int v : e[u]) {
    if (v != f) dfs1(v, u);
  }
}

void dfs2(int u, int f) {
  for (int i = 1; i <= bsz; i++) {
    sum[i][u] = sum[i][findanc(u, i)] + a[u];
  }
  for (int v : e[u]) {
    if (v != f) dfs2(v, u);
  }
}

int query1(int u, int v, int k) {
  if (u == v) return a[u];
  int lca = findlca(u, v), res = 0;
  int delta = dep[u] - dep[lca];
  int anc = findanc(u, delta - delta % k);
  res = sum[k][u] - sum[k][anc] + a[anc], u = anc;
  if (v == lca) return res;
  int s = dep[u] + dep[v] - dep[lca] - dep[lca];
  if (s > k && s % k) {
    res += a[v], v = findanc(v, s % k);
  }
  s = dep[u] + dep[v] - dep[lca] - dep[lca];
  int tmp = s > k ? findanc(v, s - k) : v;
  res += sum[k][v] - sum[k][tmp] + a[tmp];
  return res;
}

int query2(int u, int v, int k) {
  if (u == v) return a[u];
  int lca = findlca(u, v), res = a[u] + a[v];
  while (1) {
    int anc = findanc(u, k);
    if (dep[anc] < dep[lca] || (v == lca && dep[anc] == dep[v])) {
      break;
    }
    res += a[anc], u = anc;
  }
  v = findanc(v, (dep[u] + dep[v] - dep[lca] - dep[lca]) % k);
  while (1) {
    int anc = findanc(v, k);
    if (dep[anc] <= dep[lca]) {
      break;
    }
    res += a[anc], v = anc;
  }
  return res;
}

int main() {
  scanf("%d", &n);
  bsz = sqrt(n);
  for (int i = 1; i <= n; i++) {
    scanf("%d", a + i);
  }
  for (int i = 1; i < n; i++) {
    int u, v;
    scanf("%d %d", &u, &v);
    e[u].push_back(v), e[v].push_back(u);
  }
  dfs1(1, 0), dfs2(1, 0);
  static int step[maxn];
  for (int i = 1; i <= n; i++) {
    scanf("%d", step + i);
  }
  for (int i = 2; i <= n; i++) {
    int u = step[i - 1], v = step[i], k;
    scanf("%d", &k);
    printf("%d\n", k <= bsz ? query1(u, v, k) : query2(u, v, k));
  }
  return 0;
}
posted @ 2019-05-09 09:15  cnJuanzhang  阅读(271)  评论(0编辑  收藏  举报