[bzoj4765]普通计算姬——分块
Brief Description
给定一棵n个节点的带权树,节点编号为1到n,以root为根,设sum[p]表示以点p为根的这棵子树中所有节点的权
值和。支持下列两种操作:
1 给定两个整数u,v,修改点u的权值为v。
2 给定两个整数l,r,计算sum[l]+sum[l+1]+....+sum[r-1]+sum[r]
Algorithm Design
我们考察暴力算法:
对于查询,我们如果处理出所有的sum[i]就可以处理了。考虑到是树上的子树查询,我们考虑使用dfs序,使用BIT维护即可,这个算法是\(\Theta(log n)\)的。
对于修改,如果我们能够记录每个点是否可以被这个修改的点影响,我们就可以扫一遍就ok了。这样的复杂度是\(\Theta(nlogn)\)的。
我们可以看到修改的复杂度比较大,我们考虑使用分块平衡一下这两个算法,假设我们\(h(n)\)分一块,
对于查询,我们可以把查询区间分为\(\Theta(\frac{n}{h(n)})\)个区间,对于每个区间直接统计,对于两边的区间直接统计复杂度\(\Theta(h(n)log(h(n))\)
对于修改,我们统计每一个区域会被修改的点修改几次,直接统计影响即可。复杂度是\(\Theta(\frac{n}{h(n)})\)
这样,算法总的复杂度就是\(\Theta(h(n)log(h(n))+\frac{n}{h(n)})\),为了方便,我们设\(h(n) = \sqrt n\)。
Code
#include <cmath>
#include <cstdio>
#define ll unsigned long long
const ll maxn = 100010;
const ll maxm = 320;
ll value[maxn], bit[maxn << 1], wtf[maxn];
int f[maxn][maxm];
int t[maxm], l[maxn], r[maxn], head[maxn], b[maxn];
ll n, m, rt, ind = 0, cnt = 0, blockm, block;
struct edge {
ll to, next;
} e[maxn << 2];
ll lowbit(ll x) { return x & -x; }
void change(ll pos, ll x) {
for (ll i = pos; i <= n; i += lowbit(i))
bit[i] += x;
}
ll sum(ll pos) {
ll ans = 0;
for (ll i = pos; i > 0; i -= lowbit(i))
ans += bit[i];
return ans;
}
void insert(ll x, ll y) {
e[++cnt].to = y;
e[cnt].next = head[x];
head[x] = cnt;
e[++cnt].to = x;
e[cnt].next = head[y];
head[y] = cnt;
}
void dfs(ll x, ll fa) {
t[b[x]]++;
for (ll i = 1; i <= blockm; i++)
f[x][i] = t[i];
l[x] = ++ind;
change(l[x], value[x]);
for (ll i = head[x]; i; i = e[i].next) {
if (e[i].to != fa)
dfs(e[i].to, x);
}
r[x] = ind;
t[b[x]]--;
}
int main() {
#ifndef ONLINE_JUDGE
freopen("input", "r", stdin);
#endif
scanf("%llu %llu", &n, &m);
block = (int)sqrt(n);
blockm = (n - 1) / block + 1;
for (ll i = 1; i <= n; i++) {
scanf("%llu", &value[i]);
}
for (ll i = 1; i <= n; i++)
b[i] = (i - 1) / block + 1;
for (ll i = 1; i <= n; i++) {
ll u, v;
scanf("%llu %llu", &u, &v);
if (u == 0)
rt = v;
else
insert(u, v);
}
dfs(rt, 0);
for (ll i = 1; i <= n; i++)
wtf[b[i]] += sum(r[i]) - sum(l[i] - 1);
for (ll i = 1; i <= m; i++) {
ll op, u, v;
scanf("%llu %llu %llu", &op, &u, &v);
if (op == 1) {
change(l[u], v - value[u]);
for (ll j = 1; j <= blockm; j++)
wtf[j] += 1ll * f[u][j] * (v - value[u]);
value[u] = v;
} else {
ll ans = 0;
if (b[u] == b[v])
for (ll j = u; j <= v; j++)
ans += sum(r[j]) - sum(l[j] - 1);
else {
for (ll j = b[u] + 1; j <= b[v] - 1; j++)
ans += wtf[j];
for (ll j = u; b[j] == b[u] && j <= n; j++)
ans += sum(r[j]) - sum(l[j] - 1);
for (ll j = v; b[j] == b[v] && j; j--)
ans += sum(r[j]) - sum(l[j] - 1);
}
printf("%llu\n", ans);
}
}
}