【ybtoj高效进阶 21270】三只企鹅(树链剖分)(线段树)

三只企鹅

题目链接:ybtoj高效进阶 21270

题目大意

给你一棵树,然后要你支持一些操作。
给一个点的权值加一(一开始都是 0),计算所有点到一个点的距离乘各自点的权值。

思路

考虑把每个距离拆成 \(deg_x+deg_y-2deg_{lca}\)

然后不难发现就第三项比较难搞。
考虑这么一种计算方法,在放点的时候,把它到根节点的路径上的边都加一,然后询问的时候它的根节点的路径的值的和就是第三项,感性理解即可看出你找到的就是 lca 到根的路径。

然后这个加的过程和查的过程可以用树链剖分和线段树实现。

代码

#include<cstdio>
#include<iostream>
#define ll long long

using namespace std;

struct node {
	int x, to, nxt;
}e[400001];
int n, m, le[200001], KK, tmpp;
int x, y, op, z, deg[200001], dy[200001];
int fa[200001], son[200001], top[200001];
int sz[200001], dfn[200001], cnt;
ll lsum, tmp[200001], degtmp[200001];

void add(int x, int y, int z) {
	e[++KK] = (node){z, y, le[x]}; le[x] = KK;
	e[++KK] = (node){z, x, le[y]}; le[y] = KK;
}

void dfs1(int now, int father) {//树链剖分预处理
	deg[now] = deg[father] + 1;
	fa[now] = father;
	sz[now] = 1;
	
	for (int i = le[now]; i; i = e[i].nxt)
		if (e[i].to != father) {
			tmp[e[i].to] = e[i].x;
			degtmp[e[i].to] = degtmp[now] + tmp[e[i].to];
			dfs1(e[i].to, now);
			sz[now] += sz[e[i].to];
			if (sz[e[i].to] > sz[son[now]]) son[now] = e[i].to;
		}
}

void dfs2(int now, int father) {
	dfn[now] = ++tmpp;
	dy[tmpp] = now;
	
	if (son[now]) {
		top[son[now]] = top[now];
		dfs2(son[now], now);
	}
	for (int i = le[now]; i; i = e[i].nxt)
		if (e[i].to != father && e[i].to != son[now]) {
			top[e[i].to] = e[i].to;
			dfs2(e[i].to, now);
		}
}

struct XDtree {//线段树
	ll a[800001], sum[800001];
	ll lzy[800001];
	
	void up(int now) {
		a[now] = a[now << 1] + a[now << 1 | 1];
		sum[now] = sum[now << 1] + sum[now << 1 | 1];
	}
	
	void down(int now) {
		if (!lzy[now]) return ;
		sum[now << 1] += a[now << 1] * lzy[now];
		sum[now << 1 | 1] += a[now << 1 | 1] * lzy[now];
		lzy[now << 1] += lzy[now];
		lzy[now << 1 | 1] += lzy[now];
		lzy[now] = 0;
	}
	
	void build(int now, int l, int r) {
		if (l == r) {
			a[now] = tmp[dy[l]];
			return ;
		}
		
		int mid = (l + r) >> 1;
		build(now << 1, l, mid);
		build(now << 1 | 1, mid + 1, r);
		up(now);
	}
	
	void insert(int now, int l, int r, int L, int R, ll t) {
		if (L <= l && r <= R) {
			sum[now] += a[now] * t;
			lzy[now] += t;
			return ;
		}
		
		down(now);
		int mid = (l + r) >> 1;
		if (L <= mid) insert(now << 1, l, mid, L, R, t);
		if (mid < R) insert(now << 1 | 1, mid + 1, r, L, R, t);
		up(now);
	}
	
	ll query(int now, int l, int r, int L, int R) {
		if (L <= l && r <= R) {
			return sum[now];
		}
		
		down(now);
		int mid = (l + r) >> 1;
		ll re = 0;
		if (L <= mid) re += query(now << 1, l, mid, L, R);
		if (mid < R) re += query(now << 1 | 1, mid + 1, r, L, R);
		return re;
	}
}T;

int main() {
//	freopen("express.in", "r", stdin);
//	freopen("express.out", "w", stdout);
	
	scanf("%d %d", &n, &m);
	for (int i = 1; i < n; i++) {
		scanf("%d %d %d", &x, &y, &z);
		add(x, y, z);
	}
	
	dfs1(1, 0);
	top[1] = 1;
	dfs2(1, 0);
	
	T.build(1, 1, n);
	while (m--) {
		scanf("%d %d", &op, &x);
		if (op == 1) {
			lsum += degtmp[x]; cnt++;
			while (x) {
				T.insert(1, 1, n, dfn[top[x]], dfn[x], 1);
				x = fa[top[x]];
			}
		}
		if (op == 2) {
			ll re = lsum + 1ll * cnt * degtmp[x];
			while (x) {
				re -= 2ll * T.query(1, 1, n, dfn[top[x]], dfn[x]);
				x = fa[top[x]];
			}
			printf("%lld\n", re);
		}
	}
	
	return 0;
}
posted @ 2021-10-27 07:11  あおいSakura  阅读(29)  评论(0编辑  收藏  举报