bzoj4381 [POI2015]Odwiedziny
给定一棵带点权的树,每次询问在 \(u\) 到 \(v\) 的路径上,每次走 \(k\) 步,如果最后不足 \(k\) 步就走到了 \(v\) ,则会一步走到 \(v\) ,求每次行走的经过的点的点权和
\(n\leq5\times10^4,\ a_i\leq10^4\)
根号分治
考虑根号分治,如果 \(k>\sqrt n\) ,每次暴力枚举走到的节点,反之,预处理所有点每次走 \(i(i\leq\sqrt n)\) 步直到超过根的深度所经过的点权和(与 \(i\) 到根的路径不同),询问时计算贡献。
对于 \(k>\sqrt n\) 的询问,需特判 \(u=v\) 与 \(v=lca\) 的情况,并且需要快速查询一个点的 \(k\) 级祖先,用倍增即可,时间复杂度 \(O(\sqrt n\log n)\)
对于 \(k\leq\sqrt n\) 的询问,同样要特判 \(u=v\) 与 \(v=lca\) 的情况,还需注意 \(lca\to v\) 的路径的贡献。预处理时间复杂度 \(O(n\sqrt n\log n)\) ,单次查询 \(O(\log n)\)
综上,时间复杂度 \(O(n\sqrt n\log n)\) ,貌似把倍增换成长链剖分可以做到 \(O(n\sqrt n)\) ?
代码(开 C++11)
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e4 + 10;
int n, bsz, a[maxn], dep[maxn], fa[16][maxn], sum[225][maxn];
vector <int> e[maxn];
int findlca(int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
for (int i = 15; ~i; i--) {
if (dep[u] - (1 << i) >= dep[v]) {
u = fa[i][u];
}
}
if (u == v) return u;
for (int i = 15; ~i; i--) {
if (fa[i][u] != fa[i][v]) {
u = fa[i][u], v = fa[i][v];
}
}
return fa[0][u];
}
int findanc(int u, int k) {
for (int i = 0; i < 16; i++) {
if (k >> i & 1) u = fa[i][u];
}
return u;
}
void dfs1(int u, int f) {
fa[0][u] = f;
dep[u] = dep[f] + 1;
for (int i = 1; i < 16; i++) {
fa[i][u] = fa[i - 1][fa[i - 1][u]];
}
for (int v : e[u]) {
if (v != f) dfs1(v, u);
}
}
void dfs2(int u, int f) {
for (int i = 1; i <= bsz; i++) {
sum[i][u] = sum[i][findanc(u, i)] + a[u];
}
for (int v : e[u]) {
if (v != f) dfs2(v, u);
}
}
int query1(int u, int v, int k) {
if (u == v) return a[u];
int lca = findlca(u, v), res = 0;
int delta = dep[u] - dep[lca];
int anc = findanc(u, delta - delta % k);
res = sum[k][u] - sum[k][anc] + a[anc], u = anc;
if (v == lca) return res;
int s = dep[u] + dep[v] - dep[lca] - dep[lca];
if (s > k && s % k) {
res += a[v], v = findanc(v, s % k);
}
s = dep[u] + dep[v] - dep[lca] - dep[lca];
int tmp = s > k ? findanc(v, s - k) : v;
res += sum[k][v] - sum[k][tmp] + a[tmp];
return res;
}
int query2(int u, int v, int k) {
if (u == v) return a[u];
int lca = findlca(u, v), res = a[u] + a[v];
while (1) {
int anc = findanc(u, k);
if (dep[anc] < dep[lca] || (v == lca && dep[anc] == dep[v])) {
break;
}
res += a[anc], u = anc;
}
v = findanc(v, (dep[u] + dep[v] - dep[lca] - dep[lca]) % k);
while (1) {
int anc = findanc(v, k);
if (dep[anc] <= dep[lca]) {
break;
}
res += a[anc], v = anc;
}
return res;
}
int main() {
scanf("%d", &n);
bsz = sqrt(n);
for (int i = 1; i <= n; i++) {
scanf("%d", a + i);
}
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d %d", &u, &v);
e[u].push_back(v), e[v].push_back(u);
}
dfs1(1, 0), dfs2(1, 0);
static int step[maxn];
for (int i = 1; i <= n; i++) {
scanf("%d", step + i);
}
for (int i = 2; i <= n; i++) {
int u = step[i - 1], v = step[i], k;
scanf("%d", &k);
printf("%d\n", k <= bsz ? query1(u, v, k) : query2(u, v, k));
}
return 0;
}