// P3384 【模板】树链剖分 (Template: heavy-light decomposition)
#include<bits/stdc++.h>
// Portable alternative to the GCC-specific <bits/stdc++.h>:
// <iostream> <cstdio> <cmath> <algorithm> <utility>
using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> Pi;
typedef pair<ll, ll> Pii;

const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const double PI = acos(-1);

// The segment tree needs ~4n nodes, so this supports n up to about 1e5.
const int maxn = 400010;

// Heavy-light decomposition state; nodes are 1-indexed.
int dep[maxn];   // depth (root has depth 1)
int f[maxn];     // parent (-1 for the root)
int siz[maxn];   // subtree size
int son[maxn];   // heavy child (0 when the node has no children)
int top[maxn];   // head node of the heavy chain containing this node
int id[maxn];    // position in the heavy-first DFS order (1..n)
int rev[maxn];   // inverse of id: position -> node
int head[maxn];  // adjacency-list head (-1 = no edges)
ll val[maxn];    // initial node values

// Lazy segment tree over DFS positions; all stored values are in [0, mod).
// long long so that `w * segment_length` and child sums cannot overflow.
ll sum[maxn], lazy[maxn];

int cnt, tot, n, m, root, mod;

int gcd(int a, int b) { return b == 0 ? a : gcd(b, a % b); }
// Divide before multiplying so the intermediate product cannot overflow.
int lcm(int a, int b) { return a / gcd(a, b) * b; }

// Reads a (possibly negative) int from stdin; returns 0 if input ends first.
int read() {
    int flag = 1, x = 0;
    int c = getchar();  // int, not char, so EOF is detectable
    while ((c < '0' || c > '9') && c != EOF) {
        if (c == '-') flag = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * flag;
}

// long long variant of read().
ll Read() {
    int flag = 1;
    ll x = 0;
    int c = getchar();
    while ((c < '0' || c > '9') && c != EOF) {
        if (c == '-') flag = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * flag;
}

// a * b mod `mod` by repeated doubling (avoids 64-bit overflow of a * b).
ll quickmul(ll a, ll b) {
    ll ans = 0;
    while (b) {
        if (b & 1) ans = (ans + a) % mod;
        a = (a + a) % mod;
        b >>= 1;
    }
    return ans;
}

// a^b mod `mod` by binary exponentiation.
ll quickpow(ll a, ll b) {
    ll ans = 1;
    while (b) {
        if (b & 1) ans = (ans * a) % mod;
        a = (a * a) % mod;
        b >>= 1;
    }
    return ans;
}

// Reduces x into [0, mod), also for negative x.
ll norm(ll x) {
    x %= mod;
    return x < 0 ? x + mod : x;
}

struct Edge {
    int v, next;
} e[maxn << 1];

// Appends the directed edge u -> v to u's adjacency list.
void add(int u, int v) {
    e[tot] = Edge{v, head[u]};
    head[u] = tot++;
}

// First pass: parent, depth, subtree size and heavy child of every node.
// Recursive, so recursion depth equals the tree height.
void dfs1(int u, int fa, int d) {
    f[u] = fa;
    dep[u] = d;
    siz[u] = 1;
    for (int i = head[u]; i != -1; i = e[i].next) {
        int v = e[i].v;
        if (v == fa) continue;
        dfs1(v, u, d + 1);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) son[u] = v;  // siz[0] == 0 acts as "no heavy child yet"
    }
}

// Second pass: assigns DFS positions, visiting the heavy child first so that
// every heavy chain and every subtree occupies a contiguous id range.
void dfs2(int u, int t) {
    top[u] = t;
    id[u] = ++cnt;
    rev[cnt] = u;
    if (!son[u]) return;
    dfs2(son[u], t);
    for (int i = head[u]; i != -1; i = e[i].next) {
        int v = e[i].v;
        if (v != son[u] && v != f[u]) dfs2(v, v);  // each light child starts a new chain
    }
}

// Lowest common ancestor of u and v via chain jumping.
int LCA(int u, int v) {
    while (top[u] != top[v]) {
        // Always jump from the chain whose head is deeper.
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        u = f[top[u]];
    }
    return dep[u] < dep[v] ? u : v;
}

void pushup(int rt) {
    sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % mod;
}

// Adds w (already in [0, mod)) to every element of node rt, which covers [l, r].
void apply_add(int rt, int l, int r, ll w) {
    sum[rt] = (sum[rt] + w * (r - l + 1)) % mod;
    lazy[rt] = (lazy[rt] + w) % mod;
}

// Propagates the pending addition of node rt (covering [l, r]) to its children.
void pushdown(int rt, int l, int r) {
    if (lazy[rt]) {
        int mid = (l + r) >> 1;
        apply_add(rt << 1, l, mid, lazy[rt]);
        apply_add(rt << 1 | 1, mid + 1, r, lazy[rt]);
        lazy[rt] = 0;
    }
}

// Builds node rt over positions [l, r] from the initial node values.
void build(int l, int r, int rt) {
    lazy[rt] = 0;
    if (l == r) {
        sum[rt] = norm(val[rev[l]]);
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    pushup(rt);
}

// Adds w (in [0, mod)) to every position in [L, R]; node rt covers [l, r].
void update(int L, int R, ll w, int l, int r, int rt) {
    if (L <= l && r <= R) {
        apply_add(rt, l, r, w);
        return;
    }
    pushdown(rt, l, r);
    int mid = (l + r) >> 1;
    if (L <= mid) update(L, R, w, l, mid, rt << 1);
    if (R > mid) update(L, R, w, mid + 1, r, rt << 1 | 1);
    pushup(rt);
}

// Sum of positions [L, R] modulo `mod`; node rt covers [l, r].
ll querysum(int L, int R, int l, int r, int rt) {
    if (L <= l && r <= R) return sum[rt];
    pushdown(rt, l, r);
    int mid = (l + r) >> 1;
    ll ans = 0;
    if (L <= mid) ans += querysum(L, R, l, mid, rt << 1);
    if (R > mid) ans += querysum(L, R, mid + 1, r, rt << 1 | 1);
    return ans % mod;
}

// Adds z (in [0, mod)) to every node on the tree path x..y.
void update_Range(int x, int y, ll z) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        update(id[top[x]], id[x], z, 1, n, 1);
        x = f[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    update(id[x], id[y], z, 1, n, 1);
}

// Adds w (in [0, mod)) to every node in the subtree of x.
void update_root(int x, ll w) {
    update(id[x], id[x] + siz[x] - 1, w, 1, n, 1);
}

// Sum of node values on the tree path x..y, modulo `mod`.
ll sum_Range(int x, int y) {
    ll ans = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        ans = (ans + querysum(id[top[x]], id[x], 1, n, 1)) % mod;
        x = f[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    ans = (ans + querysum(id[x], id[y], 1, n, 1)) % mod;
    return ans;
}

// Sum of node values in the subtree of x, modulo `mod`.
ll sum_root(int x) {
    return querysum(id[x], id[x] + siz[x] - 1, 1, n, 1);
}

// Input: n m root mod, n node values, n-1 undirected edges, then m operations:
//   1 x y z : add z on path x..y       2 x y : print path sum x..y
//   3 x z   : add z on subtree of x    4 x   : print subtree sum of x
int main() {
    n = read();
    m = read();
    root = read();
    mod = read();
    for (int i = 1; i <= n; i++) head[i] = -1;
    for (int i = 1; i <= n; i++) val[i] = Read();
    for (int i = 1; i <= n - 1; i++) {
        int u = read();
        int v = read();
        add(u, v);
        add(v, u);
    }
    dfs1(root, -1, 1);
    dfs2(root, root);
    build(1, n, 1);
    while (m--) {
        int op = read();
        if (op == 1) {
            int x = read();
            int y = read();
            ll z = norm(Read());
            update_Range(x, y, z);
        } else if (op == 2) {
            int x = read();
            int y = read();
            cout << sum_Range(x, y) << '\n';
        } else if (op == 3) {
            int x = read();
            ll z = norm(Read());
            update_root(x, z);
        } else {
            int x = read();
            cout << sum_root(x) << '\n';
        }
    }
    return 0;
}