[lnsyoj2240/luoguP3591]ODW
题意
给定一棵 \(n\) 个节点的树和数列 \(a,b,c\),分别表示点权,移动序列和步长。在第 \(i\) 次移动中,将会从节点 \(b_i\) 移动到节点 \(b_{i+1}\),步长为 \(c_i\)。求移动时经过的所有点的点权之和。
赛时 0PTS
赛后
对于一条路径 \(x\to y\),我们将其拆成 \(x\to lca\to y\),这样,我们只需要分别计算 \(lca\to x\),\(lca\to y\) 这两条路径。
考虑一种暴力的算法:每次 \(O(\log n)\) 向上跳 \(c_i\) 步,并累加答案。显然,当 \(c_i\) 较大时,这种做法可行,但较小时就会超过时限。
另一种可行的算法是:预处理出从第 \(i\) 个节点向上每次跳 \(j\) 个节点,一直到根所经过的所有点的点权之和,即\(g_{i,j} = g_{up(i,j), j}\),然后利用类似差分的方法,\(O(1)\) 求出和。但是这样做 \(c\) 较大时,会超过空间(甚至无法通过编译)
因此,我们设计出了满足两种情况的算法,因此我们考虑根号分治,设计阈值 \(S=\sqrt{n}\),若 \(c>S\),则使用暴力算法;若 \(c<S\),则使用预处理算法。
代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
using namespace std;
const int N = 50005, M = 50005, K = 16, SMAX = 230;
int h[N], e[M], ne[M], idx;
int a[N], b[N], c[N];
int n;
int f[N][K], depth[N];
int g[N][SMAX];
int S;
void add(int a, int b){
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
void dfs_init(int u, int fa){
for (int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if (j == fa) continue;
depth[j] = depth[u] + 1;
f[j][0] = u;
dfs_init(j, u);
}
}
int up(int u, int k){
for (int i = 0; i < K; i ++ ) {
if (k & 1) u = f[u][i];
k >>= 1;
}
return u;
}
void dfs_init2(int u, int fa){
for (int k = 1; k <= S; k ++ )
g[u][k] = g[up(u, k)][k] + a[u];
for (int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if (j == fa) continue;
dfs_init2(j, u);
}
}
void init(){
depth[1] = 1;
dfs_init(1, -1);
S = sqrt(n);
for (int k = 1; k < K; k ++ )
for (int i = 1; i <= n; i ++ )
f[i][k] = f[f[i][k - 1]][k - 1];
dfs_init2(1, -1);
}
int lca(int u, int v){
if (depth[u] < depth[v]) swap(u, v);
for (int i = K - 1; i >= 0; i -- ) {
int ne = f[u][i];
if (depth[ne] >= depth[v]) u = ne;
}
if (u == v) return u;
for (int i = K - 1; i >= 0; i -- ) {
int nu = f[u][i], nv = f[v][i];
if (nu != nv) u = nu, v = nv;
}
return f[u][0];
}
void solve1(int st, int ed, int step){
int r = lca(st, ed);
int ans = 0;
int u = st;
while (depth[u] >= depth[r]){
ans += a[u];
u = up(u, step);
}
u = ed;
while (depth[u] > depth[r]){
ans += a[u];
u = up(u, step);
}
printf("%d\n", ans);
}
void solve2(int st, int ed, int step){
int r = lca(st, ed);
int ans = 0;
bool flag = true;
int ueddist = step - (depth[st] - depth[r]) % step;
if (ueddist != step) flag = false;
int ued = up(r, ueddist);
ans += g[st][step] - g[ued][step];
ueddist = step - (depth[ed] - depth[r]) % step;
if (ueddist != step) flag = false;
ued = up(r, ueddist);
ans += g[ed][step] - g[ued][step];
if (flag) ans -= a[r];
printf("%d\n", ans);
}
int main(){
memset(h, -1, sizeof h);
scanf("%d", &n);
for (int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);
for (int i = 1; i < n; i ++ ){
int u, v;
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
for (int i = 1; i <= n; i ++ ) scanf("%d", &b[i]);
for (int i = 1; i < n; i ++ ) scanf("%d", &c[i]);
init();
for (int i = 1; i < n; i ++ ){
int st = b[i], ed = b[i + 1], step = c[i];
if (step > S) solve1(st, ed, step);
else solve2(st, ed, step);
}
return 0;
}