P3384 【模板】轻重链剖分

树链剖分：将树上任意一条路径拆分为 O(log n) 段，每段都落在同一条重链上（在 DFS 序中是连续区间），再用线段树维护这些区间的和。

dfs2 时先遍历重儿子（保证同一条重链上的编号连续），再遍历轻儿子；遍历轻儿子时要跳过父节点和重儿子。线段树建立在重新分配的 DFS 序编号上，而不是原始节点编号上。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>

using namespace std;

const int N = 1e5 + 100;

// One directed half of an undirected tree edge: target vertex and the index of
// the next edge leaving the same source (0 terminates the list). Each
// undirected edge is stored twice, hence 2 * N slots.
struct Edge {
	int v;
	int nxt;
};
Edge e[N << 1];

// Segment-tree node over DFS-order positions: covered interval [l, r],
// pending lazy addition (tag) and the interval sum (kept modulo mod).
struct Node {
	int l;
	int r;
	int tag;
	int sum;
};
Node tr[N << 2];

// Problem input: vertex count, operation count, root and modulus.
int n;
int m;
int rt;
int mod;

// Adjacency list storage: number of edges used so far and list heads.
int cnt;
int head[N];

// weight[v] is the weight of vertex v; wgt[p] is the weight of the vertex
// placed at DFS-order position p.
int weight[N];
int wgt[N];

// Heavy-light decomposition data: subtree size, parent, depth, heavy child,
// chain head, DFS-order position, and the running position counter.
int siz[N];
int fa[N];
int dep[N];
int son[N];
int top[N];
int id[N];
int idx;

// Append the directed edge u -> v to u's adjacency list by head insertion.
// Edge indices start at 1 so that 0 can mark the end of a list.
void AddEdge(int u, int v) {
	const int k = ++cnt;
	e[k].nxt = head[u];
	e[k].v = v;
	head[u] = k;
}

// First DFS of the heavy-light decomposition, visiting u whose parent is ff.
// Fills fa[u], dep[u] (= dep[ff] + 1), siz[u] (subtree size) and son[u], the
// child with the largest subtree. son[u] stays 0 for a leaf because it is a
// zero-initialized global.
void dfs1(int u, int ff) {
	fa[u] = ff;
	dep[u] = dep[ff] + 1;
	siz[u] = 1;
	int maxson = 0;  // largest child subtree size seen so far
	for(int i = head[u]; i; i = e[i].nxt) {
		int v = e[i].v;
		if( v == fa[u])	continue;
		dfs1(v, u);
		siz[u] += siz[v];
		if( siz[v] > maxson) {
			// Record the new maximum. Without this, every child with siz > 0
			// wins and son[u] becomes the last child, not the heaviest. That
			// breaks the O(log n) light-edge bound the decomposition relies on.
			maxson = siz[v];
			son[u] = v;
		}
	}
}

// Second DFS: assign DFS-order positions so that each heavy chain occupies a
// contiguous range. top[u] is the head of u's chain, id[u] is u's position,
// and wgt[] receives the vertex weights laid out in that order.
void dfs2(int u, int topf) {
	top[u] = topf;
	id[u] = ++idx;
	wgt[id[u]] = weight[u];
	// The heavy child continues the current chain, so it must be numbered next.
	if (son[u] != 0)
		dfs2(son[u], topf);
	// Each remaining child (not the parent, not the heavy child) heads a new chain.
	for (int i = head[u]; i != 0; i = e[i].nxt) {
		const int v = e[i].v;
		if (v != fa[u] && v != son[u])
			dfs2(v, v);
	}
}

// Recompute node u's sum from its two children, modulo mod.
// The addition is done in long long: two values below mod (which may be as
// large as 2^31 - 1) can exceed INT_MAX when added as int.
void pushup(int u) {
	const long long total = (long long)tr[u << 1].sum + tr[u << 1 | 1].sum;
	tr[u].sum = (int)(total % mod);
}

void build(int u, int l, int r) {
	tr[u].l = l, tr[u].r = r;
	if( l == r) {
		tr[u].sum = wgt[l];
		return;
	}
	int mid = l + r >> 1;
	build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
	pushup(u); 
} 

// Push node u's pending lazy addition down to both children, then clear it.
// Each child's sum grows by tag * (child interval length). The product is
// formed in long long: a length up to n times a tag up to mod - 1 does not
// fit in int.
void pushdown(int u) {
	if( tr[u].tag) {
		const long long t = tr[u].tag;
		Node & ls = tr[u << 1], &rs = tr[u << 1 | 1];
		ls.sum = (int)((ls.sum + (long long)(ls.r - ls.l + 1) * t) % mod);
		ls.tag = (int)((ls.tag + t) % mod);
		rs.sum = (int)((rs.sum + (long long)(rs.r - rs.l + 1) * t) % mod);
		rs.tag = (int)((rs.tag + t) % mod);
		tr[u].tag = 0;
	}
}

// Add v to every DFS-order position in [l, r] within node u's subtree
// (range add with lazy propagation). Stored values are kept modulo mod.
// The arithmetic is widened to long long because v * length and tag + v can
// both overflow int.
void change(int u, int l, int r, int v) {
	if( l <= tr[u].l && r >= tr[u].r) {
		const long long len = tr[u].r - tr[u].l + 1;
		tr[u].tag = (int)((tr[u].tag + (long long)v) % mod);
		tr[u].sum = (int)((tr[u].sum + len * v) % mod);
		return;
	}
	pushdown(u);
	int mid = (tr[u].l + tr[u].r) >> 1;
	if( l <= mid) 	change(u << 1, l, r, v);
	if( r > mid)	change(u << 1 | 1, l, r, v);
	pushup(u);
}

int ask(int u, int l, int r) {
	if( l <= tr[u].l && r >= tr[u].r) {
		return tr[u].sum;
	}
	pushdown(u);
	int res = 0;
	int mid = tr[u].l + tr[u].r >> 1;
	if( l <= mid)	res += ask(u << 1, l, r);
	if( r > mid)	res = (res + ask(u << 1 | 1, l, r)); 
	return res % mod;
}

// Return the sum of vertex weights on the tree path between x and y, modulo mod.
// Repeatedly lift the endpoint whose chain head is deeper, adding the whole
// chain segment [top, endpoint] (contiguous in DFS order). Once both endpoints
// share a chain, add the remaining segment between them.
int query(int x, int y) {
	// long long: res + ask(...) can exceed INT_MAX when mod > 2^30.
	long long res = 0;
	while( top[x] != top[y]) {
		if( dep[top[x]] < dep[top[y]])	swap(x, y);
		res = (res + ask(1, id[top[x]], id[x])) % mod;
		x = fa[top[x]];
	}
	if( dep[x] < dep[y])	swap(x, y);
	res = (res + ask(1, id[y], id[x])) % mod;
	return (int)res;
}

// Add v to every vertex on the tree path between x and y.
// Mirrors the path walk in query(): lift the endpoint with the deeper chain
// head one chain at a time, then update the final segment on the shared chain.
void modify(int x, int y, int v) {
	while (top[x] != top[y]) {
		if (dep[top[x]] < dep[top[y]])
			swap(x, y);
		change(1, id[top[x]], id[x], v);
		x = fa[top[x]];
	}
	// On the same chain the shallower vertex has the smaller DFS position.
	if (dep[x] > dep[y])
		swap(x, y);
	change(1, id[x], id[y], v);
}

int main()
{
	// Header: n vertices, m operations, root rt, modulus mod.
	scanf("%d%d%d%d", &n, &m, &rt, &mod);
	for (int v = 1; v <= n; ++v)
		scanf("%d", weight + v);
	// n - 1 undirected edges, each stored in both directions.
	for (int k = 1; k < n; ++k) {
		int a, b;
		scanf("%d%d", &a, &b);
		AddEdge(a, b);
		AddEdge(b, a);
	}
	// Decompose the tree rooted at rt, then build the segment tree over DFS order.
	dfs1(rt, 0);
	dfs2(rt, rt);
	build(1, 1, n);
	while (m--) {
		int opt, x, y, z;
		scanf("%d", &opt);
		switch (opt) {
		case 1:  // add z on path x..y
			scanf("%d%d%d", &x, &y, &z);
			modify(x, y, z);
			break;
		case 2:  // sum on path x..y
			scanf("%d%d", &x, &y);
			printf("%d\n", query(x, y));
			break;
		case 3:  // add z to subtree of x (a contiguous DFS-order range)
			scanf("%d%d", &x, &z);
			change(1, id[x], id[x] + siz[x] - 1, z);
			break;
		default:  // sum over subtree of x
			scanf("%d", &x);
			printf("%d\n", ask(1, id[x], id[x] + siz[x] - 1));
			break;
		}
	}
	return 0;
}
posted @ 2020-10-06 14:20  王雨阳  阅读(103)  评论(0)  编辑  收藏  举报