Loading

洛谷-P3178 树上操作

树上操作

树链剖分模板题：单点加、子树加、查询根到结点的路径和

在树链剖分中，一棵子树内所有结点的 dfn 序一定构成一段连续区间，因此只需求出子树内最大的 dfn 序即可：在树链剖分的第二次 dfs 中，回溯到当前结点时，记录此时 dfn 序已经分配到的值（即代码中的 bot 数组），子树就对应区间 [dfn[x], bot[x]]。

然后直接用线段树做区间加、区间求和即可；查询根到结点的路径和时，沿重链向上跳，逐段求和。

#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
// Shared state for the heavy-light decomposition and the segment tree.
using ll = long long;
constexpr int maxn = 100000 + 10;           // capacity for node ids / dfn positions
std::vector<int> gra[maxn];                 // adjacency lists of the tree
// dep: depth (root = 1), siz: subtree size, hson: heavy child (-1 for a leaf), fa: parent
int dep[maxn], siz[maxn], hson[maxn], fa[maxn];
// dfn: visit index, bot: largest dfn in the subtree, rnk: node at a dfn, top: chain head
int dfn[maxn], bot[maxn], rnk[maxn], top[maxn];
// tr / lazy: segment-tree sums and pending adds (4 * maxn nodes); w: initial weights
ll tr[maxn << 2], lazy[maxn << 2], w[maxn];

// First pass of the decomposition over the subtree of `now`, entered from
// `pre` at depth `d`.
// Records depth, parent and subtree size. Picks as heavy child the first child
// (in adjacency order) whose subtree is strictly larger than any seen before.
// hson stays -1 for a leaf. Recursion depth equals the height of the tree.
void dfs1(int now, int pre, int d)
{
    fa[now] = pre;
    dep[now] = d;
    hson[now] = -1;
    siz[now] = 1;
    for(size_t i = 0; i < gra[now].size(); ++i)
    {
        const int child = gra[now][i];
        if(child != pre)
        {
            dfs1(child, now, d + 1);
            siz[now] += siz[child];
            const bool heavier = hson[now] == -1 || siz[child] > siz[hson[now]];
            if(heavier)
                hson[now] = child;
        }
    }
}

// Last dfn value handed out so far; init() resets it to 0.
int tp = 0;

// Second pass of the decomposition: number nodes so that every heavy chain and
// every subtree occupies a contiguous dfn range.
// `t` is the head of the chain that `now` belongs to. Afterwards
// rnk[dfn[v]] == v, and [dfn[v], bot[v]] is the dfn range of v's subtree.
void dfs2(int now, int t)
{
    top[now] = t;
    dfn[now] = ++tp;
    rnk[dfn[now]] = now;
    const int heavy = hson[now];
    if(heavy != -1)
    {
        // Visit the heavy child first so it continues the current chain.
        dfs2(heavy, t);
        for(const int child : gra[now])
        {
            if(child == fa[now] || child == heavy)
                continue;
            dfs2(child, child);   // every light child starts a new chain
        }
    }
    bot[now] = tp;
}

// Build segment-tree node `now`, which covers dfn positions [l, r].
// The leaf at position l holds w[rnk[l]], the weight of the node with that dfn.
// Lazy tags are cleared on the way down.
void build(int now, int l, int r)
{
    lazy[now] = 0;
    if(l != r)
    {
        const int mid = (l + r) >> 1;
        const int left = now << 1;
        const int right = left | 1;
        build(left, l, mid);
        build(right, mid + 1, r);
        tr[now] = tr[left] + tr[right];
    }
    else
    {
        tr[now] = w[rnk[l]];
    }
}

// Decompose the tree stored in gra (nodes 1..n) rooted at `rt`, then build
// the segment tree over dfn positions 1..n.
// Finally empties gra[0..n]; the adjacency lists are not used after this.
void init(int n, int rt = 1)
{
    tp = 0;
    dfs1(rt, rt, 1);   // the root is recorded as its own parent
    dfs2(rt, rt);      // the root heads its own chain
    build(1, 1, n);
    for(int v = 0; v <= n; ++v)
        gra[v].clear();
}

inline void push_down(int now, int l, int r)
{
    if(lazy[now])
    {
        int lson = now << 1, rson = now << 1 | 1, mid = l + r >> 1;
        tr[lson] += lazy[now] * (mid - l + 1);
        tr[rson] += lazy[now] * (r - mid);
        lazy[lson] += lazy[now];
        lazy[rson] += lazy[now];
        lazy[now] = 0;
    }
}

// Add `val` to every position of [L, R] that lies inside node `now`'s range
// [l, r], using lazy tags for fully covered nodes.
void update(int now, int l, int r, int L, int R, ll val)
{
    const bool covered = l >= L && r <= R;
    if(covered)
    {
        lazy[now] += val;
        tr[now] += val * (r - l + 1);
        return;
    }
    push_down(now, l, r);
    const int mid = (l + r) >> 1;
    const int left = now << 1;
    const int right = left | 1;
    if(mid >= L)
        update(left, l, mid, L, R, val);
    if(mid < R)
        update(right, mid + 1, r, L, R, val);
    tr[now] = tr[left] + tr[right];
}

// Return the sum over positions of [L, R] that lie inside node `now`'s range
// [l, r]. Pending tags are pushed down along the way.
ll query(int now, int l, int r, int L, int R)
{
    if(l >= L && r <= R)
        return tr[now];
    push_down(now, l, r);
    const int mid = (l + r) >> 1;
    ll sum = 0;
    if(mid >= L)
        sum += query(now << 1, l, mid, L, R);
    if(mid < R)
        sum += query(now << 1 | 1, mid + 1, r, L, R);
    return sum;
}

// Sum of node weights on the tree path between `a` and `b`, both endpoints
// included. `n` is the segment-tree range [1, n].
//
// Method: while the two endpoints are on different heavy chains, take the one
// whose chain head is deeper. Add its chain segment [dfn[top], dfn[node]] and
// jump to the parent of that head. Once both are on one chain, the part
// between them is a single contiguous dfn range.
//
// Works for any root passed to init() and any pair of nodes. solve(n, 1, x)
// with the tree rooted at 1 gives the root->x path sum.
ll solve(int n, int a, int b)
{
    ll ans = 0;
    while(top[a] != top[b])
    {
        // Lift the endpoint whose chain head is deeper (never the root's chain).
        if(dep[top[a]] < dep[top[b]])
            swap(a, b);
        ans += query(1, 1, n, dfn[top[a]], dfn[a]);
        a = fa[top[a]];
    }
    // Same chain now: the shallower node has the smaller dfn.
    if(dep[a] > dep[b])
        swap(a, b);
    ans += query(1, 1, n, dfn[a], dfn[b]);
    return ans;
}

// Read n weights and n-1 edges, decompose the tree rooted at 1, then answer
// m operations:
//   1 x a : add a to node x
//   2 x a : add a to every node in the subtree of x
//   other : print the sum of weights on the path from the root to x
int main()
{
    int n, m;
    scanf("%d%d", &n, &m);
    for(int v = 1; v <= n; ++v)
        scanf("%lld", &w[v]);
    for(int e = 1; e < n; ++e)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        gra[u].push_back(v);
        gra[v].push_back(u);
    }
    init(n);
    while(m--)
    {
        int op;
        scanf("%d", &op);
        int x, a;
        switch(op)
        {
        case 1:
            scanf("%d%d", &x, &a);
            update(1, 1, n, dfn[x], dfn[x], a);           // single position
            break;
        case 2:
            scanf("%d%d", &x, &a);
            update(1, 1, n, dfn[x], bot[x], a);           // whole subtree range
            break;
        default:
            scanf("%d", &x);
            printf("%lld\n", solve(n, 1, x));
            break;
        }
    }
    return 0;
}
posted @ 2022-07-10 01:04  dgsvygd  阅读(28)  评论(0)  编辑  收藏  举报