[lnsyoj2240/luoguP3591]ODW

题意

给定一棵 \(n\) 个节点的树和数列 \(a,b,c\),分别表示点权,移动序列和步长。在第 \(i\) 次移动中,将会从节点 \(b_i\) 移动到节点 \(b_{i+1}\),步长为 \(c_i\)。求移动时经过的所有点的点权之和。

赛时 0PTS

赛后

对于一条路径 \(x\to y\),我们将其拆成 \(x\to lca\to y\),这样,我们只需要分别计算 \(lca\to x\)\(lca\to y\) 这两条路径。
考虑一种暴力的算法:每次 \(O(\log n)\) 向上跳 \(c_i\) 步,并累加答案。显然,当 \(c_i\) 较大时,这种做法可行,但较小时就会超过时限。
另一种可行的算法是:预处理出从第 \(i\) 个节点向上每次跳 \(j\) 个节点,一直到根所经过的所有点的点权之和,即\(g_{i,j} = g_{up(i,j), j}\),然后利用类似差分的方法,\(O(1)\) 求出和。但是这样做 \(c\) 较大时,会超过空间(甚至无法通过编译)
因此,我们设计出了满足两种情况的算法,因此我们考虑根号分治,设计阈值 \(S=\sqrt{n}\),若 \(c>S\),则使用暴力算法;若 \(c<S\),则使用预处理算法。

代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>

using namespace std;

const int N = 50005, M = 50005, K = 16, SMAX = 230;

int h[N], e[M], ne[M], idx;
int a[N], b[N], c[N];
int n;
int f[N][K], depth[N];
int g[N][SMAX];
int S;

void add(int a, int b){
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}

void dfs_init(int u, int fa){
    for (int i = h[u]; ~i; i = ne[i]){
        int j = e[i];
        if (j == fa) continue;
        depth[j] = depth[u] + 1;
        f[j][0] = u;
        dfs_init(j, u);
    }
}

int up(int u, int k){
    for (int i = 0; i < K; i ++ ) {
        if (k & 1) u = f[u][i];
        k >>= 1;
    }
    return u;
}

void dfs_init2(int u, int fa){
    for (int k = 1; k <= S; k ++ )
        g[u][k] = g[up(u, k)][k] + a[u];

    for (int i = h[u]; ~i; i = ne[i]){
        int j = e[i];
        if (j == fa) continue;
        dfs_init2(j, u);
    }
}

void init(){
    depth[1] = 1;
    dfs_init(1, -1);
    S = sqrt(n);

    for (int k = 1; k < K; k ++ )
        for (int i = 1; i <= n; i ++ )
            f[i][k] = f[f[i][k - 1]][k - 1];

    dfs_init2(1, -1);
}

int lca(int u, int v){
    if (depth[u] < depth[v]) swap(u, v);
    
    for (int i = K - 1; i >= 0; i -- ) {
        int ne = f[u][i];
        if (depth[ne] >= depth[v]) u = ne;
    }

    if (u == v) return u;

    for (int i = K - 1; i >= 0; i -- ) {
        int nu = f[u][i], nv = f[v][i];
        if (nu != nv) u = nu, v = nv;
    }

    return f[u][0];
}


void solve1(int st, int ed, int step){
    int r = lca(st, ed);
    int ans = 0;

    int u = st;
    while (depth[u] >= depth[r]){
        ans += a[u];
        u = up(u, step);
    }

    u = ed;
    while (depth[u] > depth[r]){
        ans += a[u];
        u = up(u, step);
    }

    printf("%d\n", ans);
}

void solve2(int st, int ed, int step){
    int r = lca(st, ed);
    int ans = 0;

    bool flag = true;

    int ueddist = step - (depth[st] - depth[r]) % step;
    if (ueddist != step) flag = false;
    int ued = up(r, ueddist);
    ans += g[st][step] - g[ued][step];

    ueddist = step - (depth[ed] - depth[r]) % step;
    if (ueddist != step) flag = false;
    ued = up(r, ueddist);
    ans += g[ed][step] - g[ued][step];

    if (flag) ans -= a[r];

    printf("%d\n", ans);
}

int main(){
    memset(h, -1, sizeof h);
    scanf("%d", &n);
    for (int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);
    for (int i = 1; i < n; i ++ ){
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v), add(v, u);
    }

    for (int i = 1; i <= n; i ++ ) scanf("%d", &b[i]);
    for (int i = 1; i < n; i ++ ) scanf("%d", &c[i]);
    
    init();

    for (int i = 1; i < n; i ++ ){
        int st = b[i], ed = b[i + 1], step = c[i];
        if (step > S) solve1(st, ed, step);
        else solve2(st, ed, step);
    }

    return 0;
}
posted @ 2024-08-07 11:09  是一只小蒟蒻呀  阅读(12)  评论(0编辑  收藏  举报