【ybtoj高效进阶 21270】三只企鹅(树链剖分)(线段树)
三只企鹅
题目链接:ybtoj高效进阶 21270
题目大意
给你一棵树,然后要你支持一些操作。
给一个点的权值加一(一开始都是 0),计算所有点到一个点的距离乘各自点的权值。
思路
考虑把每个距离拆成 \(deg_x+deg_y-2deg_{lca}\)。
然后不难发现就第三项比较难搞。
考虑这么一种计算方法,在放点的时候,把它到根节点的路径上的边都加一,然后询问的时候它的根节点的路径的值的和就是第三项,感性理解即可看出你找到的就是 lca 到根的路径。
然后这个加的过程和查的过程可以用树链剖分和线段树实现。
代码
#include<cstdio>
#include<iostream>
#define ll long long
using namespace std;
struct node {
int x, to, nxt;
}e[400001];
int n, m, le[200001], KK, tmpp;
int x, y, op, z, deg[200001], dy[200001];
int fa[200001], son[200001], top[200001];
int sz[200001], dfn[200001], cnt;
ll lsum, tmp[200001], degtmp[200001];
void add(int x, int y, int z) {
e[++KK] = (node){z, y, le[x]}; le[x] = KK;
e[++KK] = (node){z, x, le[y]}; le[y] = KK;
}
void dfs1(int now, int father) {//树链剖分预处理
deg[now] = deg[father] + 1;
fa[now] = father;
sz[now] = 1;
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father) {
tmp[e[i].to] = e[i].x;
degtmp[e[i].to] = degtmp[now] + tmp[e[i].to];
dfs1(e[i].to, now);
sz[now] += sz[e[i].to];
if (sz[e[i].to] > sz[son[now]]) son[now] = e[i].to;
}
}
void dfs2(int now, int father) {
dfn[now] = ++tmpp;
dy[tmpp] = now;
if (son[now]) {
top[son[now]] = top[now];
dfs2(son[now], now);
}
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father && e[i].to != son[now]) {
top[e[i].to] = e[i].to;
dfs2(e[i].to, now);
}
}
struct XDtree {//线段树
ll a[800001], sum[800001];
ll lzy[800001];
void up(int now) {
a[now] = a[now << 1] + a[now << 1 | 1];
sum[now] = sum[now << 1] + sum[now << 1 | 1];
}
void down(int now) {
if (!lzy[now]) return ;
sum[now << 1] += a[now << 1] * lzy[now];
sum[now << 1 | 1] += a[now << 1 | 1] * lzy[now];
lzy[now << 1] += lzy[now];
lzy[now << 1 | 1] += lzy[now];
lzy[now] = 0;
}
void build(int now, int l, int r) {
if (l == r) {
a[now] = tmp[dy[l]];
return ;
}
int mid = (l + r) >> 1;
build(now << 1, l, mid);
build(now << 1 | 1, mid + 1, r);
up(now);
}
void insert(int now, int l, int r, int L, int R, ll t) {
if (L <= l && r <= R) {
sum[now] += a[now] * t;
lzy[now] += t;
return ;
}
down(now);
int mid = (l + r) >> 1;
if (L <= mid) insert(now << 1, l, mid, L, R, t);
if (mid < R) insert(now << 1 | 1, mid + 1, r, L, R, t);
up(now);
}
ll query(int now, int l, int r, int L, int R) {
if (L <= l && r <= R) {
return sum[now];
}
down(now);
int mid = (l + r) >> 1;
ll re = 0;
if (L <= mid) re += query(now << 1, l, mid, L, R);
if (mid < R) re += query(now << 1 | 1, mid + 1, r, L, R);
return re;
}
}T;
int main() {
// freopen("express.in", "r", stdin);
// freopen("express.out", "w", stdout);
scanf("%d %d", &n, &m);
for (int i = 1; i < n; i++) {
scanf("%d %d %d", &x, &y, &z);
add(x, y, z);
}
dfs1(1, 0);
top[1] = 1;
dfs2(1, 0);
T.build(1, 1, n);
while (m--) {
scanf("%d %d", &op, &x);
if (op == 1) {
lsum += degtmp[x]; cnt++;
while (x) {
T.insert(1, 1, n, dfn[top[x]], dfn[x], 1);
x = fa[top[x]];
}
}
if (op == 2) {
ll re = lsum + 1ll * cnt * degtmp[x];
while (x) {
re -= 2ll * T.query(1, 1, n, dfn[top[x]], dfn[x]);
x = fa[top[x]];
}
printf("%lld\n", re);
}
}
return 0;
}