LightOJ-1348 Aladdin and the Return Journey
Aladdin and the Return Journey
树链剖分模板题 (heavy-light decomposition template problem)
结点单点修改 (point updates on node values)
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
// Capacity of every per-node array: 30000 nodes plus a small safety margin.
const int maxn = 3e4 + 10;

// Tree adjacency lists (main() stores each edge in both directions).
vector<int> gra[maxn];

// Filled by dfs1 for each node:
int fa[maxn];    // parent (init() makes the root its own parent)
int dep[maxn];   // depth (init() puts the root at depth 1)
int siz[maxn];   // number of nodes in the subtree
int hson[maxn];  // heavy child (largest subtree), -1 when there are no children

// Filled by dfs2 for each node:
int top[maxn];   // head of the heavy chain containing the node
int dfn[maxn];   // 1-based position in the heavy-child-first DFS order
int rnk[maxn];   // inverse of dfn: rnk[pos] is the node at position pos

// Weight of each node, as read from the input.
int w[maxn];

// Segment tree of range sums over DFS positions (4 * maxn slots).
int tr[maxn * 4];
// Builds the sum segment tree. Node `now` covers DFS positions [l, r].
// A leaf stores the weight of the tree node placed at that position
// (w[rnk[l]]). An inner node stores the sum of its children 2*now and 2*now+1.
void build(int now, int l, int r)
{
    if(l != r)
    {
        const int mid = (l + r) >> 1;
        const int lc = now << 1;
        const int rc = lc | 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        tr[now] = tr[lc] + tr[rc];
        return;
    }
    tr[now] = w[rnk[l]];
}
// Returns the sum of the values at DFS positions in [L, R] that fall inside
// the range [l, r] covered by segment-tree node `now`.
int query(int now, int l, int r, int L, int R)
{
    // Node range lies entirely inside the query: its stored sum is the answer.
    const bool covered = L <= l && r <= R;
    if(covered)
        return tr[now];

    // Otherwise descend only into the halves that overlap [L, R].
    const int mid = (l + r) >> 1;
    const int fromLeft = (L <= mid) ? query(now << 1, l, mid, L, R) : 0;
    const int fromRight = (R > mid) ? query(now << 1 | 1, mid + 1, r, L, R) : 0;
    return fromLeft + fromRight;
}
// Point assignment: sets DFS position x to val. The value is overwritten, not
// added. Then every node on the path from that leaf back up to `now`
// (which covers [l, r]) has its sum recomputed.
void update(int now, int l, int r, int x, int val)
{
    if(l != r)
    {
        const int mid = (l + r) >> 1;
        const int lc = now << 1;
        const int rc = lc | 1;
        if(x > mid)
            update(rc, mid + 1, r, x, val);
        else
            update(lc, l, mid, x, val);
        tr[now] = tr[lc] + tr[rc];
    }
    else
        tr[now] = val;
}
// First heavy-light pass over the subtree rooted at `now`, whose parent is
// `pre` and whose depth is `d`. For every node u in that subtree it fills:
//   fa[u]   - parent
//   dep[u]  - depth
//   siz[u]  - subtree size
//   hson[u] - the first child (in adjacency order) with the largest subtree,
//             or -1 if u has no children.
//
// Implemented iteratively: a recursive DFS nests as deep as the tree is tall
// (about 30000 frames on a path-shaped tree), which can overflow small stacks.
void dfs1(int now, int pre, int d)
{
    // Pass 1: pre-order walk that records parent, depth and the visit order.
    // Every node is appended to `order` after its parent.
    vector<int> order;
    vector<int> stk(1, now);
    fa[now] = pre;
    dep[now] = d;
    while(!stk.empty())
    {
        const int u = stk.back();
        stk.pop_back();
        order.push_back(u);
        for(size_t i = 0; i < gra[u].size(); i++)
        {
            const int v = gra[u][i];
            if(v == fa[u]) continue;
            fa[v] = u;
            dep[v] = dep[u] + 1;
            stk.push_back(v);
        }
    }

    // Pass 2: reverse order visits children before parents, so each child's
    // size is final when its parent reads it.
    for(size_t k = order.size(); k-- > 0; )
    {
        const int u = order[k];
        siz[u] = 1;
        hson[u] = -1;
        for(size_t i = 0; i < gra[u].size(); i++)
        {
            const int v = gra[u][i];
            if(v == fa[u]) continue;
            siz[u] += siz[v];
            // Strict '<' keeps the earliest child when sizes tie.
            if(hson[u] == -1 || siz[hson[u]] < siz[v])
                hson[u] = v;
        }
    }
}
// Running counter of assigned DFS positions; init() resets it to 0.
int tp = 0;
// Second heavy-light pass. Numbers the subtree of `now` in DFS order,
// visiting the heavy child first, so every heavy chain gets consecutive
// positions. For each node u it sets:
//   top[u] - chain head (`now` receives head `t`)
//   dfn[u] - position, continuing from tp + 1
//   rnk    - rnk[dfn[u]] = u
// Requires dfs1 to have filled fa and hson.
//
// Uses an explicit stack instead of recursion so depth is not limited by the
// call stack (the tree may be a path of about 30000 nodes).
void dfs2(int now, int t)
{
    vector<int> stk(1, now);
    top[now] = t;
    while(!stk.empty())
    {
        const int u = stk.back();
        stk.pop_back();
        dfn[u] = ++tp;
        rnk[tp] = u;
        if(hson[u] == -1) continue;  // no children

        // Light children start their own chains. They are pushed in reverse
        // adjacency order so they are popped, and numbered, in adjacency order.
        for(int i = (int)gra[u].size() - 1; i >= 0; i--)
        {
            const int v = gra[u][i];
            if(v == fa[u] || v == hson[u]) continue;
            top[v] = v;
            stk.push_back(v);
        }
        // The heavy child is pushed last, so its whole subtree is numbered
        // right after u. It continues u's chain.
        top[hson[u]] = top[u];
        stk.push_back(hson[u]);
    }
}
// Prepares a tree of n nodes rooted at rt for path queries. It runs both
// heavy-light passes and builds the segment tree over positions 1..n. It then
// empties adjacency lists 0..n so the next test case starts fresh.
void init(int n, int rt = 1)
{
    tp = 0;           // restart DFS numbering at position 1
    dfs1(rt, rt, 1);  // the root is its own parent and sits at depth 1
    dfs2(rt, rt);     // the root heads its own chain
    build(1, 1, n);
    for(int i = n; i >= 0; i--)
        gra[i].clear();
}
// Returns the sum of node weights on the tree path between a and b, with both
// endpoints included. n is the number of positions covered by the segment
// tree.
int solve(int a, int b, int n)
{
    int sum = 0;
    // Move up chain by chain until both endpoints lie on the same heavy chain.
    while(top[a] != top[b])
    {
        // Always lift the endpoint whose chain head is deeper.
        if(dep[top[b]] > dep[top[a]])
            swap(a, b);
        const int head = top[a];
        sum += query(1, 1, n, dfn[head], dfn[a]);  // segment head..a of this chain
        a = fa[head];                              // continue above the chain head
    }
    // Same chain: positions along a chain increase with depth, because dfs2
    // numbers the heavy child right after its parent. The remaining part is
    // therefore the contiguous range between the two positions.
    sum += query(1, 1, n, min(dfn[a], dfn[b]), max(dfn[a], dfn[b]));
    return sum;
}
// Reads the test cases. Each case gives node weights, the tree edges and a
// list of operations; the answers to path-sum queries are printed one per line.
int main()
{
    int cases;
    scanf("%d", &cases);
    for(int cs = 1; cs <= cases; ++cs)
    {
        printf("Case %d:\n", cs);

        // Node count, then one weight per node; nodes are numbered 0..n-1.
        int n;
        scanf("%d", &n);
        for(int node = 0; node < n; ++node)
            scanf("%d", &w[node]);

        // n - 1 edges, stored in both directions.
        for(int e = 1; e < n; ++e)
        {
            int u, v;
            scanf("%d%d", &u, &v);
            gra[u].push_back(v);
            gra[v].push_back(u);
        }
        init(n, 0);  // root the tree at node 0

        int q;
        scanf("%d", &q);
        while(q--)
        {
            int type;
            scanf("%d", &type);
            int a, b;
            scanf("%d%d", &a, &b);
            if(type != 0)
                update(1, 1, n, dfn[a], b);      // set the weight of node a to b
            else
                printf("%d\n", solve(a, b, n));  // sum of weights on path a..b
        }
    }
    return 0;
}