// P3384 【模板】轻重链剖分 (Luogu P3384 template: heavy-light decomposition)

#include<bits/stdc++.h>

using namespace std;

// Reads the next (optionally negative) decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' seen anywhere in
// that skipped prefix makes the result negative.
// Returns 0 if end of input is reached before any digit is found.
int readint(){
	int x = 0;
	int f = 1;
	// Keep getchar()'s result as int: only an int can distinguish EOF from every
	// byte value. With `char`, the skip loop below never terminates at end of input.
	int ch = getchar();
	while(ch != EOF && (ch < '0' || ch > '9')) {
		if(ch == '-') f = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9'){
		// Parenthesize the digit value: `x * 10 + ch - '0'` first adds the raw
		// character code and overflows int for inputs such as 2147483647.
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}

// Modulus read from the input header; used to reduce segment-tree sums and tags.
int MOD;
// Capacity of every per-vertex / per-timestamp array.
constexpr int maxn = 200005;

// One segment-tree node over the DFS-order array.
struct Tree{
	int sum;   // sum of the values in this node's range
	int lazy;  // pending addition not yet pushed down to the children
};

Tree node[maxn << 2];   // segment tree: root is 1, children of i are 2i and 2i+1
int a[maxn];            // initial value of each vertex
int id[maxn];           // id[t]  = vertex visited at DFS timestamp t
int pre[maxn];          // pre[u] = DFS timestamp of vertex u
int top[maxn];          // top[u] = head of the heavy chain containing u
int dep[maxn];          // dep[u] = depth of u (the root stays at 0)
int fa[maxn];           // fa[u]  = parent of u (0 for the root)
int son[maxn];          // son[u] = heavy child of u (0 when u has no children)
int siz[maxn];          // siz[u] = number of vertices in u's subtree
vector<int> e[maxn];    // undirected adjacency lists
int m;                  // number of operations
int n;                  // number of vertices
int r;                  // root vertex

// Recomputes node i's sum (mod MOD) from its two children.
// The addition is done in 64-bit: MOD may be as large as INT_MAX, so two child
// sums that are each close to MOD would overflow if added as int.
void push_up(int i){
	const long long total = 1LL * node[i << 1].sum + node[i << 1|1].sum;
	node[i].sum = static_cast<int>(total % MOD);
}

// Pushes node i's pending addition down to its two children and clears it.
// `len` is the number of positions node i covers. The left child covers
// ceil(len / 2) of them and the right child covers floor(len / 2), matching the
// split at mid = (l + r) >> 1 used by build/update/query.
// The parameter was previously named `m`, which shadowed the global operation count.
// Products are formed in 64-bit because tag * len easily exceeds int.
void push_down(int i,int len){
	const long long tag = node[i].lazy;
	const int left_len = len - (len >> 1);
	const int right_len = len >> 1;
	node[i << 1].lazy = static_cast<int>((node[i << 1].lazy + tag) % MOD);
	node[i << 1|1].lazy = static_cast<int>((node[i << 1|1].lazy + tag) % MOD);
	node[i << 1].sum = static_cast<int>((node[i << 1].sum + tag * left_len) % MOD);
	node[i << 1|1].sum = static_cast<int>((node[i << 1|1].sum + tag * right_len) % MOD);
	node[i].lazy = 0;
}

void build(int i,int l,int r){
	if(l == r){
		node[i].sum = a[id[l]];
		return;
	} 
	int mid = l + r >> 1;
	build(i << 1,l,mid);
	build(i << 1|1,mid + 1,r); 
	push_up(i); 
}

// Accumulator filled by query(); callers (qRange, qSon) set it to 0, call query, then read it.
int res{};

// Adds v to every position in [L, R] inside node i's range [l, r].
// A node that is fully covered absorbs the addition into its lazy tag instead of
// recursing; partially covered nodes push their tag down first, then recurse.
// All arithmetic is 64-bit and reduced modulo MOD:
//  - v * (r - l + 1) can exceed int even when v < MOD;
//  - the lazy tag is reduced here, because otherwise repeated updates on the same
//    node would grow it without bound and overflow.
void update(int i,int l,int r,int L,int R,int v){
	if(l >= L && r <= R) {
		const long long add = v % MOD;
		node[i].lazy = static_cast<int>((node[i].lazy + add) % MOD);
		node[i].sum = static_cast<int>((node[i].sum + add * (r - l + 1)) % MOD);
		return;
	}
	if(node[i].lazy) push_down(i,r - l + 1);
	int mid = (l + r) >> 1;
	if(L <= mid) update(i << 1,l,mid,L,R,v);
	if(R > mid) update(i << 1|1,mid + 1,r,L,R,v);
	push_up(i);
}


// Adds the sum of positions [L, R] inside node i's range [l, r] to the global
// accumulator `res`, keeping it reduced modulo MOD.
// Callers must set `res` to 0 before the top-level call.
void query(int i,int l,int r,int L,int R){
	if(l >= L && r <= R){
		// res and the node sum can each be close to MOD (up to INT_MAX), so add them in 64-bit.
		res = static_cast<int>((1LL * res + node[i].sum) % MOD);
		return;
	}
	if(node[i].lazy) push_down(i,r - l + 1);
	int mid = (l + r) >> 1;
	if(L <= mid) query(i << 1,l,mid,L,R);
	if(R > mid) query(i << 1|1,mid + 1,r,L,R);
}

// First heavy-light pass, started at the root.
// For every vertex below u it records the parent (fa), depth (dep) and subtree
// size (siz). It also picks the heavy child (son): the child with the largest
// subtree, where the earliest visited child wins ties.
void dfs1(int u){
	siz[u] = 1;
	for(size_t j = 0; j < e[u].size(); ++j){
		const int child = e[u][j];
		if(child != fa[u]){
			fa[child] = u;
			dep[child] = dep[u] + 1;
			dfs1(child);
			siz[u] += siz[child];
			if(siz[son[u]] < siz[child]) son[u] = child;
		}
	}
}

// Last DFS timestamp handed out by dfs2.
int tt = 0;

// Second heavy-light pass: assigns DFS timestamps (pre / id) and chain heads (top).
// `x` is the head of the heavy chain that u belongs to. Visiting the heavy child
// first keeps every heavy chain contiguous in timestamp order. Each light child
// starts a new chain headed by itself.
void dfs2(int u,int x){
	const int stamp = ++tt;
	pre[u] = stamp;
	id[stamp] = u;
	top[u] = x;
	if(son[u]){
		dfs2(son[u], x);
		for(const int child : e[u]){
			if(child != fa[u] && child != son[u]) dfs2(child, child);
		}
	}
}

// Returns the sum (mod MOD) of the values on the tree path between x and y.
// It repeatedly lifts whichever endpoint's chain head is deeper, summing that
// chain segment, until both endpoints lie on the same heavy chain.
// The running total is kept in 64-bit: adding two values that are each close to
// MOD (up to INT_MAX) would overflow int.
int qRange(int x,int y){
	long long ans = 0;
	while(top[x] != top[y]){
		if(dep[top[x]] < dep[top[y]]) swap(x,y);
		res = 0;
		query(1,1,n,pre[top[x]],pre[x]);
		ans = (ans + res) % MOD;
		x = fa[top[x]];
	}
	if(dep[x] > dep[y]) swap(x,y);
	res = 0;
	query(1,1,n,pre[x],pre[y]);
	ans = (ans + res) % MOD;
	return static_cast<int>(ans);
}

// Returns the sum (mod MOD) of all values in the subtree of x.
// The subtree occupies the contiguous DFS range [pre[x], pre[x] + siz[x] - 1].
int qSon(int x){
	const int first = pre[x];
	const int last = first + siz[x] - 1;
	res = 0;
	query(1, 1, n, first, last);
	return res;
}

// Adds k (reduced mod MOD) to every vertex on the tree path between x and y.
// It climbs chain by chain from the endpoint whose chain head is deeper, updating
// that chain's contiguous DFS range. It finishes with the segment between the
// two endpoints once they share a chain.
void updRange(int x,int y,int k){
	const int delta = k % MOD;
	for(; top[x] != top[y]; x = fa[top[x]]){
		if(dep[top[y]] > dep[top[x]]) swap(x, y);
		update(1, 1, n, pre[top[x]], pre[x], delta);
	}
	if(dep[y] < dep[x]) swap(x, y);
	update(1, 1, n, pre[x], pre[y], delta);
}

void updSon(int x,int k){
	update(1,1,n,pre[x],pre[x] + siz[x] - 1,k);
}

// Reads the tree and initial values, builds the heavy-light decomposition and the
// segment tree, then processes the operations:
//   1 x y z : add z to every vertex on the path x-y
//   2 x y   : print the path sum x-y
//   3 x z   : add z to every vertex in the subtree of x
//   other x : print the subtree sum of x
int main(){
	n = readint();
	m = readint();
	r = readint();
	MOD = readint();
	for(int vertex = 1; vertex <= n; ++vertex)
		a[vertex] = readint();
	for(int edge = 1; edge < n; ++edge){
		const int u = readint();
		const int w = readint();
		e[u].push_back(w);
		e[w].push_back(u);
	}
	dfs1(r);
	dfs2(r, r);
	build(1, 1, n);
	while(m-- != 0){
		const int op = readint();
		switch(op){
			case 1: {
				const int x = readint();
				const int y = readint();
				const int z = readint();
				updRange(x, y, z);
				break;
			}
			case 2: {
				const int x = readint();
				const int y = readint();
				printf("%d\n", qRange(x, y));
				break;
			}
			case 3: {
				const int x = readint();
				const int z = readint();
				updSon(x, z);
				break;
			}
			default: {
				const int x = readint();
				printf("%d\n", qSon(x));
				break;
			}
		}
	}
}
// Originally posted 2020-11-02 22:06 by MQFLLY.