Luogu - P3384 树链剖分 [挂模板专用]
题意:请码个树剖模板支持子树区间加/查询和路径加/查询
纯练手
盲敲技能++
以后网络赛复制模板速度++++
对链操作时注意方向
#include<bits/stdc++.h>
#define rep(i,j,k) for(register int i=j;i<=k;i++)
#define rrep(i,j,k) for(register int i=j;i>=k;i--)
#define println(a) printf("%lld\n",(ll)a)
using namespace std;
const int MAXN = 2e5+11;
typedef long long ll;
ll MOD;
ll read(){
ll x=0,f=1;register char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int to[MAXN<<1],nxt[MAXN<<1],head[MAXN],tot; ll cost[MAXN<<1];
void add(int u,int v,ll w=0){
to[tot]=v;
cost[tot]=w;
nxt[tot]=head[u];
head[u]=tot++;
}
int size[MAXN],dfn[MAXN],dfned[MAXN],pre[MAXN],p[MAXN],son[MAXN],dep[MAXN],top[MAXN],CLOCK;
void init(){
memset(head,-1,sizeof head);
memset(son,0,sizeof son);
tot=CLOCK=0;
}
void dfs(int u,int fa,int d){
p[u]=fa; dep[u]=d; size[u]=1;
son[u]=top[u]=0;
for(int i=head[u];~i;i=nxt[i]){
int v=to[i];
if(v==fa) continue;
dfs(v,u,d+1);
size[u]+=size[v];
if(size[v]>size[son[u]]){
son[u]=v;
}
}
}
void dfs2(int u,int tp){
top[u]=tp;
dfn[u]=++CLOCK;
pre[CLOCK]=u;
if(son[u]) dfs2(son[u],tp);
for(int i=head[u];~i;i=nxt[i]){
int v=to[i];
if(v==p[u]||v==son[u]) continue;
dfs2(v,v);
}
dfned[u]=CLOCK;
}
int a[MAXN];
struct ST{
ll sum[MAXN<<2];
int lazy[MAXN<<2];
#define lc o<<1
#define rc o<<1|1
void pu(int o){
sum[o]=sum[lc]+sum[rc];
if(sum[o]>=MOD) sum[o]%=MOD;
}
void build(int o,int l,int r){
sum[o]=lazy[o]=0;
if(l==r){
sum[o]=a[pre[l]]; //!!
return;
}
int mid=l+r>>1;
build(lc,l,mid);
build(rc,mid+1,r);
pu(o);
}
void pd(int o,int l,int r){
if(lazy[o]){
int mid=l+r>>1;
sum[lc]+=(ll)(mid-l+1)*lazy[o];
sum[rc]+=(ll)(r-mid)*lazy[o];
if(sum[lc]>MOD) sum[lc]%=MOD;
if(sum[rc]>MOD) sum[rc]%=MOD;
lazy[lc]+=lazy[o];
lazy[rc]+=lazy[o];
lazy[o]=0;
}
}
void update(int o,int l,int r,int L,int R,ll v){
if(L<=l&&r<=R){
sum[o]+=(ll)(r-l+1)*v;
if(sum[o]>=MOD) sum[o]%=MOD;
lazy[o]+=v;
return;
}
pd(o,l,r);
int mid=l+r>>1;
if(L<=mid) update(lc,l,mid,L,R,v);
if(R>mid) update(rc,mid+1,r,L,R,v);
pu(o);
}
ll query(int o,int l,int r,int L,int R){
if(L<=l&&r<=R){
return sum[o];
}
pd(o,l,r);
int mid=l+r>>1;
ll ans=0;
if(L<=mid) ans+=query(lc,l,mid,L,R);
if(R>mid) ans+=query(rc,mid+1,r,L,R);
return ans>=MOD?ans%MOD:ans;
}
}st;
void solve(int u,int v,int val,int op){
ll ans=0;
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
if(op==2) ans+=st.query(1,1,CLOCK,dfn[top[u]],dfn[u]);
if(op==2&&ans>=MOD) ans%=MOD;
else st.update(1,1,CLOCK,dfn[top[u]],dfn[u],val);
u=p[top[u]];
}
//same link
if(dep[u]<dep[v]) swap(u,v);
if(op==2) ans+=st.query(1,1,CLOCK,dfn[v],dfn[u]);//att
if(op==2&&ans>=MOD) ans%=MOD;
else st.update(1,1,CLOCK,dfn[v],dfn[u],val);
if(op==2) println(ans%MOD);
}
int main(){
int n,m,rt;
while(cin>>n>>m>>rt>>MOD){
init();
rep(i,1,n) a[i]=read();
rep(i,1,n-1){
int u=read();
int v=read();
add(u,v);
add(v,u);
}
dfs(rt,-1,1); dfs2(rt,rt);
st.build(1,1,CLOCK);
while(m--){
int op=read();
if(op==1){
int u=read();
int v=read();
ll val=read();
solve(u,v,val,op);
}else if(op==2){
int u=read();
int v=read();
solve(u,v,0,op);
}else if(op==3){
int x=read();
ll v=read();
int l=dfn[x],r=dfned[x];
st.update(1,1,CLOCK,l,r,v);
}else{
int x=read();
int l=dfn[x],r=dfned[x];
println(st.query(1,1,CLOCK,l,r)%MOD);
}
}
}
return 0;
}