洛谷-P3178 树上操作
树上操作
树链剖分模板 - 单点加、子树整体加、根到结点的路径求和
考虑到树链剖分时,一棵子树内所有结点的 dfn 序一定是连续的一段区间,因此只要求出子树内最大的 dfn 序即可:在树链剖分的第二遍 dfs 中,回溯到当前结点时,记录一下 dfn 序当前已经分配到了哪个值(即代码中的 bot 数组)
然后子树修改直接用线段树对 [dfn[x], bot[x]] 做区间加即可;查询根到结点的路径和时,沿重链不断向上跳,对每一段链在线段树上区间求和
#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
// 64-bit accumulator type: path/subtree sums can exceed 32 bits.
using ll = long long;
// Capacity for vertices (1e5 plus slack); vertices are indexed from 1.
constexpr int maxn = 100000 + 10;
// Adjacency lists of the input tree.
vector<int> gra[maxn];
// HLD pass 1 results: depth, subtree size, heavy child (-1 = none), parent.
int dep[maxn], siz[maxn], hson[maxn], fa[maxn];
// HLD pass 2 results: dfs order, last dfs order inside the subtree,
// vertex owning a given dfs order, and top vertex of the heavy chain.
int dfn[maxn], bot[maxn], rnk[maxn], top[maxn];
// Segment tree sums and pending additions (4x leaves), and vertex weights.
ll tr[maxn << 2], lazy[maxn << 2], w[maxn];
// First heavy-light decomposition pass over the subtree rooted at `now`.
// Records parent (`pre`), depth (`d`) and subtree size of every vertex, and
// chooses each vertex's heavy child: the child with the largest subtree,
// keeping the earliest-visited child on ties. hson is -1 for leaves.
void dfs1(int now, int pre, int d)
{
    fa[now] = pre;
    dep[now] = d;
    siz[now] = 1;
    hson[now] = -1;
    for(const int child : gra[now])
    {
        if(child != pre)
        {
            dfs1(child, now, d + 1);
            siz[now] += siz[child];
            // strict '>' keeps the first child among equally large subtrees
            const bool heavier = (hson[now] == -1) || (siz[child] > siz[hson[now]]);
            if(heavier)
                hson[now] = child;
        }
    }
}
// Last dfs order number handed out; reset to 0 by init() before dfs2 runs.
int tp = 0;
// Second heavy-light decomposition pass. Numbers vertices in dfs order,
// visiting the heavy child first so each heavy chain occupies consecutive
// positions. `t` is the top vertex of the chain `now` belongs to. After the
// subtree of `now` is finished, bot[now] is the largest dfn inside it, so the
// subtree maps to the contiguous range [dfn[now], bot[now]].
void dfs2(int now, int t)
{
    dfn[now] = ++tp;
    rnk[dfn[now]] = now;
    top[now] = t;
    if(hson[now] == -1)
    {
        // leaf: the subtree is just this vertex
        bot[now] = tp;
        return;
    }
    dfs2(hson[now], t);                 // heavy child continues the chain
    for(const int child : gra[now])
        if(child != fa[now] && child != hson[now])
            dfs2(child, child);         // each light child starts a new chain
    bot[now] = tp;
}
// Builds segment-tree node `now` covering dfn positions [l, r]. The leaf at
// position l stores the weight of vertex rnk[l]; internal nodes store the sum
// of their children. All pending additions are cleared.
void build(int now, int l, int r)
{
    lazy[now] = 0;
    if(l != r)
    {
        const int mid = (l + r) / 2;
        const int lc = now * 2, rc = now * 2 + 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        tr[now] = tr[lc] + tr[rc];
    }
    else
    {
        tr[now] = w[rnk[l]];
    }
}
// Runs both heavy-light decomposition passes on the tree rooted at `rt`,
// builds the segment tree over dfn positions 1..n, and finally clears the
// adjacency lists gra[0..n] (nothing after this point in the file reads them).
void init(int n, int rt = 1)
{
    tp = 0;             // restart dfn numbering
    dfs1(rt, rt, 1);    // the root is recorded as its own parent, depth 1
    dfs2(rt, rt);       // the root heads its own heavy chain
    build(1, 1, n);
    for(auto *adj = gra; adj != gra + n + 1; ++adj)
        adj->clear();
}
inline void push_down(int now, int l, int r)
{
if(lazy[now])
{
int lson = now << 1, rson = now << 1 | 1, mid = l + r >> 1;
tr[lson] += lazy[now] * (mid - l + 1);
tr[rson] += lazy[now] * (r - mid);
lazy[lson] += lazy[now];
lazy[rson] += lazy[now];
lazy[now] = 0;
}
}
// Adds `val` to every position in [L, R], starting from node `now` which
// covers [l, r]. Nodes fully inside [L, R] are updated lazily.
void update(int now, int l, int r, int L, int R, ll val)
{
    if(l >= L && r <= R)
    {
        lazy[now] += val;
        tr[now] += (r - l + 1) * val;
        return;
    }
    push_down(now, l, r);
    const int mid = (l + r) / 2;
    const int lc = now * 2, rc = lc + 1;
    if(mid >= L)
        update(lc, l, mid, L, R, val);
    if(mid < R)
        update(rc, mid + 1, r, L, R, val);
    tr[now] = tr[lc] + tr[rc];
}
// Returns the sum of positions [L, R] within node `now` covering [l, r],
// pushing pending additions down along the way.
ll query(int now, int l, int r, int L, int R)
{
    if(l >= L && r <= R)
        return tr[now];
    push_down(now, l, r);
    const int mid = (l + r) / 2;
    const ll from_left = (mid >= L) ? query(now * 2, l, mid, L, R) : 0;
    const ll from_right = (mid < R) ? query(now * 2 + 1, mid + 1, r, L, R) : 0;
    return from_left + from_right;
}
// Sum of vertex weights on the tree path between `a` and `b` (both ends
// included), using the heavy-light decomposition built by init() and the
// segment tree over positions 1..n.
//
// Previously `a` was ignored and vertex 1 was hard-wired as the root, which
// broke for init(n, rt) with rt != 1. Now the path is computed between the
// two given vertices. When `a` is the root, the result is the root-to-`b`
// sum, exactly as before for solve(n, 1, x) on a tree rooted at 1.
ll solve(int n, int a, int b)
{
    ll ans = 0;
    // Climb from whichever endpoint sits on the chain with the deeper top
    // until both endpoints lie on the same heavy chain.
    while(top[a] != top[b])
    {
        int &deeper = (dep[top[a]] >= dep[top[b]]) ? a : b;
        ans += query(1, 1, n, dfn[top[deeper]], dfn[deeper]);
        deeper = fa[top[deeper]];
    }
    // Same chain: its dfn positions between the two vertices are contiguous.
    ans += query(1, 1, n, std::min(dfn[a], dfn[b]), std::max(dfn[a], dfn[b]));
    return ans;
}
// Reads n vertex weights and n-1 undirected edges, builds the decomposition
// rooted at vertex 1, then answers m operations:
//   1 x a : add a to vertex x
//   2 x a : add a to every vertex in the subtree of x
//   other : print the sum of weights on the path from vertex 1 to x
int main()
{
    int n, m;
    scanf("%d%d", &n, &m);
    for(int v = 1; v <= n; ++v)
        scanf("%lld", &w[v]);
    for(int e = 1; e < n; ++e)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        gra[u].push_back(v);
        gra[v].push_back(u);
    }
    init(n);
    while(m--)
    {
        int op;
        scanf("%d", &op);
        int x;
        switch(op)
        {
        case 1:
        case 2:
        {
            int a;
            scanf("%d%d", &x, &a);
            // a single vertex is [dfn[x], dfn[x]]; a subtree is [dfn[x], bot[x]]
            update(1, 1, n, dfn[x], op == 1 ? dfn[x] : bot[x], a);
            break;
        }
        default:
            scanf("%d", &x);
            printf("%lld\n", solve(n, 1, x));
            break;
        }
    }
    return 0;
}