Luogu - P3384 树链剖分 [挂模板专用]

题意:请码个树剖模板支持子树区间加/查询和路径加/查询

纯练手

盲敲技能++

以后网络赛复制模板速度++++

对链操作时注意方向

#include<bits/stdc++.h>
#define rep(i,j,k) for(register int i=j;i<=k;i++)
#define rrep(i,j,k) for(register int i=j;i>=k;i--)
#define println(a) printf("%lld\n",(ll)a)
using namespace std;
const int MAXN = 2e5+11;
typedef long long ll;
ll MOD;
ll read(){
    ll x=0,f=1;register char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int to[MAXN<<1],nxt[MAXN<<1],head[MAXN],tot; ll cost[MAXN<<1];
void add(int u,int v,ll w=0){
    to[tot]=v;
    cost[tot]=w;
    nxt[tot]=head[u];
    head[u]=tot++;
}
int size[MAXN],dfn[MAXN],dfned[MAXN],pre[MAXN],p[MAXN],son[MAXN],dep[MAXN],top[MAXN],CLOCK;
void init(){
    memset(head,-1,sizeof head);
    memset(son,0,sizeof son);
    tot=CLOCK=0;
}
void dfs(int u,int fa,int d){
    p[u]=fa; dep[u]=d; size[u]=1;
    son[u]=top[u]=0;
    for(int i=head[u];~i;i=nxt[i]){
        int v=to[i];
        if(v==fa) continue;
        dfs(v,u,d+1);
        size[u]+=size[v];
        if(size[v]>size[son[u]]){
            son[u]=v;
        }
    }
}
void dfs2(int u,int tp){
    top[u]=tp;
    dfn[u]=++CLOCK;
    pre[CLOCK]=u;
    if(son[u]) dfs2(son[u],tp);
    for(int i=head[u];~i;i=nxt[i]){
        int v=to[i];
        if(v==p[u]||v==son[u]) continue;
        dfs2(v,v);
    }
    dfned[u]=CLOCK; 
}
int a[MAXN];
struct ST{
    ll sum[MAXN<<2];
    int lazy[MAXN<<2];
    #define lc o<<1
    #define rc o<<1|1
    void pu(int o){
        sum[o]=sum[lc]+sum[rc];
        if(sum[o]>=MOD) sum[o]%=MOD;
    }
    void build(int o,int l,int r){
        sum[o]=lazy[o]=0;
        if(l==r){
            sum[o]=a[pre[l]]; //!!
            return;
        }
        int mid=l+r>>1;
        build(lc,l,mid);
        build(rc,mid+1,r);
        pu(o);
    }
    void pd(int o,int l,int r){
        if(lazy[o]){
            int mid=l+r>>1;
            sum[lc]+=(ll)(mid-l+1)*lazy[o];
            sum[rc]+=(ll)(r-mid)*lazy[o];
            if(sum[lc]>MOD) sum[lc]%=MOD;
            if(sum[rc]>MOD) sum[rc]%=MOD;
            lazy[lc]+=lazy[o];
            lazy[rc]+=lazy[o];
            lazy[o]=0;
        }
    }
    void update(int o,int l,int r,int L,int R,ll v){
        if(L<=l&&r<=R){
            sum[o]+=(ll)(r-l+1)*v;
            if(sum[o]>=MOD) sum[o]%=MOD;
            lazy[o]+=v;
            return;
        }
        pd(o,l,r);
        int mid=l+r>>1;
        if(L<=mid) update(lc,l,mid,L,R,v);
        if(R>mid) update(rc,mid+1,r,L,R,v);
        pu(o);
    }
    ll query(int o,int l,int r,int L,int R){
        if(L<=l&&r<=R){
            return sum[o];
        }
        pd(o,l,r);
        int mid=l+r>>1;
        ll ans=0;
        if(L<=mid) ans+=query(lc,l,mid,L,R);
        if(R>mid) ans+=query(rc,mid+1,r,L,R);
        return ans>=MOD?ans%MOD:ans;
    }
}st;
void solve(int u,int v,int val,int op){
    ll ans=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        if(op==2) ans+=st.query(1,1,CLOCK,dfn[top[u]],dfn[u]);
        if(op==2&&ans>=MOD) ans%=MOD;
        else st.update(1,1,CLOCK,dfn[top[u]],dfn[u],val);
        u=p[top[u]];
    }
    //same link
    if(dep[u]<dep[v]) swap(u,v);
    if(op==2) ans+=st.query(1,1,CLOCK,dfn[v],dfn[u]);//att
    if(op==2&&ans>=MOD) ans%=MOD;
    else st.update(1,1,CLOCK,dfn[v],dfn[u],val);
    if(op==2) println(ans%MOD);
}
int main(){
    int n,m,rt;
    while(cin>>n>>m>>rt>>MOD){
        init();
        rep(i,1,n) a[i]=read();
        rep(i,1,n-1){
            int u=read();
            int v=read();
            add(u,v);
            add(v,u);
        }
        dfs(rt,-1,1); dfs2(rt,rt);
        st.build(1,1,CLOCK);
        while(m--){
            int op=read();
            if(op==1){
                int u=read();
                int v=read();
                ll val=read();
                solve(u,v,val,op);
            }else if(op==2){
                int u=read();
                int v=read();
                solve(u,v,0,op);
            }else if(op==3){
                int x=read();
                ll v=read();
                int l=dfn[x],r=dfned[x];
                st.update(1,1,CLOCK,l,r,v);
            }else{
                int x=read();
                int l=dfn[x],r=dfned[x];
                println(st.query(1,1,CLOCK,l,r)%MOD);
            }
        }
    }
    return 0;
}
posted @ 2018-08-13 14:53  Caturra  阅读(131)  评论(0编辑  收藏  举报