P3384

P3384 【模板】重链剖分/树链剖分 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

](https://www.luogu.com.cn/problem/P3384)

#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define int long long
#define mem(a) memset(a,0, sizeof(a))
#define db(x) cout<<x<<"\n";
#define _db(a,n) for(int i=1;i<=n;i++) cout<<a[i]<<" \n"[i==n];
#define rep(i,l,r) for(int i=l;i<=r;i++) 
#define per(i,r,l) for(int i=r;i>=l;i--)
#define lc p<<1
#define rc p<<1|1
const int N=1e5+5;
int fa[N],siz[N],son[N],dep[N],top[N];
vector<int>e[N];
//分别存父节点、当前节点的大小、当前节点的重儿子、当前节点的深度、当前节点所在重链的顶点、邻接表建图
//重儿子:当前节点的子节点中size最大的那个
//重链:由若干条首尾衔接的重边构成,把落单的结点——叶子节点也当作重链,那么整棵树就被剖分成若干条重链
//每个节点的轻儿子,一定作为重链的起点
void dfs1(int cur,int father){//树上的后序遍历
	fa[cur]=father;
	siz[cur]=1;
	dep[cur]=dep[father]+1;//当前节点的深度等于父节点的深度+1
	for(auto nxt:e[cur]){//枚举子节点
		if(nxt==father) continue;//防止无限递归
		dfs1(nxt,cur);//后序遍历下一个节点
		siz[cur]+=siz[nxt];//递归回溯后更新当前节点的siz
		if(siz[son[cur]]<siz[nxt]) son[cur]=nxt;//若遍历到的子节点size比重儿子size大,更新重儿子
	}
}
int tot,w[N],id[N],nw[N];//分别是时间戳、权值、当前节点的时间戳、新编号——时间戳为编号下的权值
void dfs2(int cur,int tp){//当前节点,当前节点的所在重链的顶点
	top[cur]=tp;//标记当前节点所在的重链顶点
	id[cur]=++tot,nw[tot]=w[cur];//记录时间戳、映射权值
	if(!son[cur]) return;//若当前节点没有重儿子(只要有一个子节点,其就是它的重儿子),说明当前节点是叶子节点
	dfs2(son[cur],tp);//若还有重儿子,则继续向下,标记重儿子的top
	for(auto nxt:e[cur]){//处理轻儿子
		if(nxt==fa[cur]||nxt==son[cur]) continue;//前者是避免无限递归,后者是为了避免重复处理重儿子
		dfs2(nxt,nxt);//处理轻儿子,轻儿子的重链顶点是它自身
	}
}
struct tree{
	int l,r,sum,lz;
}tr[N<<2];
void pushup(int p){
	tr[p].sum=(tr[lc].sum+tr[rc].sum);
}
void build(int p,int l,int r){
	tr[p].l=l,tr[p].r=r,tr[p].lz=0;
	if(l==r){
		tr[p].sum=nw[l];
		return;
	}
	int mid=l+r>>1;
	build(lc,l,mid);
	build(rc,mid+1,r);
	pushup(p);
}
void pushdown(int p){
	if(tr[p].lz){
		int k=tr[p].lz;
		tr[lc].sum=(tr[lc].sum+(tr[lc].r-tr[lc].l+1)*k);
		tr[rc].sum=(tr[rc].sum+(tr[rc].r-tr[rc].l+1)*k);
		tr[lc].lz=(tr[lc].lz+k);
		tr[rc].lz=(tr[rc].lz+k);
		tr[p].lz=0;
	}
}
int query(int p,int x,int y){
	if(x<=tr[p].l&&y>=tr[p].r){
		return tr[p].sum;
	}
    pushdown(p);
	int mid=tr[p].l+tr[p].r>>1;
	int res=0;
	if(x<=mid) res=(res+query(lc,x,y));
	if(y>mid) res=(res+query(rc,x,y));
	return res;
}
int query_path(int u,int v){
	int res=0;
	while(top[u]!=top[v]){
		if(dep[top[u]]<dep[top[v]]) swap(u,v);
		res=(res+query(1,id[top[u]],id[u]));
		u=fa[top[u]];
	}
	if(dep[u]<dep[v]) swap(u,v);//保证u是深度更大的点
	res=(res+query(1,id[v],id[u]));
	return res;
}
void update(int p,int x,int y,int k){
	if(x<=tr[p].l&&y>=tr[p].r){
		tr[p].sum=(tr[p].sum+(tr[p].r-tr[p].l+1)*k);
		tr[p].lz=(tr[p].lz+k);
		return;
	}
	pushdown(p);
	int mid=tr[p].l+tr[p].r>>1;
	if(x<=mid) update(lc,x,y,k);
	if(y>mid) update(rc,x,y,k);
	pushup(p);
}
void update_path(int u,int v,int k){//给u到v节点的最短路上的节点都加上k
	while(top[u]!=top[v]){
		if(dep[top[u]]<dep[top[v]]) swap(u,v);
		update(1,id[top[u]],id[u],k);
		u=fa[top[u]];
	}
	if(dep[u]<dep[v]) swap(u,v);
	update(1,id[v],id[u],k);
}
inline int read(){
    int a=0,b=1;
    char c=getchar();
    while(c<'0'||c>'9'){
        if(c=='-')b=-1;
        c=getchar();
    }while(c>='0'&&c<='9'){
        a=(a<<3)+(a<<1)+c-'0';
        c=getchar();
    }
    return a*b;
}
inline void write(int n){
    if(n<0){
        putchar('-');
        n=-n;
    }
    if(n>=10)
        write(n/10);
    putchar(n%10+'0');
}
void solve(){
	int n,m,r,M;n=read(),m=read(),r=read(),M=read();
    //scanf("%lld %lld %lld %lld",&n,&m,&r,&M);//cin>>n>>m>>r>>M;
	rep(i,1,n) w[i]=read();//scanf("%lld",&w[i]);//cin>>w[i];  
	rep(i,1,n-1){
		int u,v;u=read(),v=read();//scanf("%lld %lld",&u,&v);//cin>>u>>v;
		e[u].push_back(v);
        e[v].push_back(u);
	}
    dfs1(r,0);
    dfs2(r,r);
    build(1,1,n);
    // _db(nw,n);
    // _db(siz,n);
    // _db(dep,n);
	while(m--){
		int op,x,y,z;op=read(),x=read(); //scanf("%lld %lld",&op,&x);
        //cin>>op>>x;
		if(op==1){
            y=read(),z=read();
			// cin>>y>>z;
            // scanf("%lld %lld",&y,&z);
			update_path(x,y,z);
		}
		else if(op==2){
            y=read();
			// cin>>y;
            // scanf("%lld",&y);
            write(query_path(x,y)%M);
            printf("\n");
			// cout<<<<"\n";
		}
		else if(op==3){
            z=read();
			// cin>>z;
            // scanf("%lld",&z);
			update(1,id[x],id[x]+siz[x]-1,z);
		}
		else if(op==4){
            write(query(1,id[x],id[x]+siz[x]-1)%M);
            printf("\n");
			// cout<<<<"\n";
		}
	}
	//cout<<"#"<<(++_)<<" "<<endl;
}
signed main()
{
std::ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
	// int t;cin>>t;while(t--)
	solve();
	return 0;
}
posted @ 2024-08-06 15:53  mono_4  阅读(6)  评论(0编辑  收藏  举报