// Heavy-light decomposition (树剖) template — Luogu P3384 (洛谷3384)

#include<bits/stdc++.h>

using namespace std;


// Capacity of the node / edge / DFS-order arrays; the segment-tree arrays
// below use 4*maxn slots.
const int maxn=5e5+10;

// ll, mem and inf are not referenced anywhere in the code visible here.
#define ll long long
#define mem(a,b) memset(a,b,sizeof(a))
#define inf 0x3f3f3f3f
// rt expands to the argument triple (l=1, r=n, now=1): the segment-tree root
// covering [1, n]. Used as e.g. updata(rt, left, right, k).
#define rt 1,n,1
// Left / right child of segment-tree node `now` (implicit heap layout).
#define ls now<<1
#define rs now<<1|1
// Argument triples for recursing into the left / right half. They rely on
// locals named l, r, mid and now existing at the call site.
#define lson l,mid,ls
#define rson mid+1,r,rs

// One directed edge u -> v in the forward-star adjacency list. `next` is the
// index of the previously added edge leaving u; 0 ends the list.
// Member order matters: creatEdge builds it as (edge){u, v, next}.
struct edge{
    int u;
    int v;
    int next;
}e[maxn];

// Per-node heavy-light decomposition data.
struct node{
    int f;      // parent (set by dfs1; the root is its own parent)
    int d;      // depth, the root has depth 1 (set by dfs1)
    int s;      // number of nodes in this node's subtree (set by dfs1)
    int son;    // heavy child, 0 if the node has no children (set by dfs1)
    int rk;     // declared but never read or written in the visible code
    int top;    // head node of the heavy chain containing this node (set by dfs2)
    int w;      // node weight read from the input
    int maxid;  // largest DFS index inside this node's subtree (set by dfs2)
}N[maxn];
// Adjacency-list heads indexed by node (0 = no outgoing edge yet).
int head[maxn];
// id[u]: 1-based DFS index that dfs2 assigns to node u.
int id[maxn];
// Node weights laid out in DFS order: v[id[u]] == N[u].w.
int v[maxn];
// Number of edges added so far (edge indices start at 1).
int tot=0;
// Segment tree over v[1..n]: tree[] holds range sums, lazy[] holds pending
// range additions; the update code takes both modulo `mod`.
int tree[maxn*4];
int lazy[maxn*4];
int mod;

// n = number of nodes, m = number of operations, r = root (read in main).
int n;
int m;
int r;

// Appends the directed edge u -> v to the front of u's forward-star list.
// Indices are pre-incremented, so the first edge gets index 1 and 0 can mark
// the end of a list.
void creatEdge(int u,int v)
{
    const int idx=++tot;
    e[idx].u=u;
    e[idx].v=v;
    e[idx].next=head[u];
    head[u]=idx;
}
// First DFS of the decomposition: fills parent (f), depth (d), subtree size (s)
// and heavy child (son) for every node reachable from u.
// A nonzero depth doubles as the "visited" mark, so N[].d must be zero for all
// nodes before the first call. Heavy-child comparison starts from N[0].s
// (index 0 presumably unused as a node id, so its size stays 0 — confirm).
void dfs1(int u,int fa,int dep)
{
    N[u].f=fa;
    N[u].d=dep;
    N[u].s=1;
    for(int i=head[u];i!=0;i=e[i].next)
    {
        const int child=e[i].v;
        // Descend only into unvisited neighbours that are not a self-loop.
        if(child!=u&&N[child].d==0)
        {
            dfs1(child,u,dep+1);
            N[u].s+=N[child].s;
            // Strict '>' keeps the first child found among equally sized ones.
            if(N[child].s>N[N[u].son].s)
                N[u].son=child;
        }
    }
}

// Last DFS index handed out by dfs2.
int mub=0;
// Second DFS: sets the chain head (top), the DFS index (id), copies the weight
// into DFS order (v), and records the last DFS index of each subtree (maxid).
// The heavy child is visited first, so it gets id[u]+1 and every heavy chain
// occupies a contiguous index range. Each light child starts its own chain.
void dfs2(int u,int fa,int t)
{
    N[u].top=t;
    id[u]=++mub;
    v[mub]=N[u].w;
    const int heavy=N[u].son;
    if(heavy!=0)
    {
        dfs2(heavy,u,t);
        for(int i=head[u];i;i=e[i].next)
        {
            const int to=e[i].v;
            // A light child is one level deeper, is not the parent, and is not
            // the heavy child.
            const bool lightChild=(to!=heavy)&&(to!=fa)&&(N[to].d==N[u].d+1);
            if(lightChild)
                dfs2(to,u,to);
        }
    }
    // All descendants have been numbered by now.
    N[u].maxid=mub;
}

void built(int l,int r,int now)
{
    int mid=(l+r)>>1;
    if(l==r)
    {
        tree[now]=v[l];
        return ;
    }
    built(lson);
    built(rson);
    tree[now]=tree[ls]+tree[rs];
    tree[now]%=mod;
}

// Moves the pending addition lazy[now] of node `now` (covering [l, r]) to its
// two children. Each child's sum grows by (child length) * tag.
// All arithmetic is in long long: `mod` is an int and can be as large as
// INT_MAX, so (length * tag) and (tag + tag) overflow int (signed overflow is
// UB). Results are reduced modulo `mod` and therefore fit back into int.
// Only the children's entries and lazy[now] are written; tree[now] is untouched.
void pushdown(int l,int r,int now)
{
    int mid=(l+r)>>1;
    const long long add=lazy[now];

    lazy[ls]=static_cast<int>((lazy[ls]+add)%mod);
    lazy[rs]=static_cast<int>((lazy[rs]+add)%mod);

    tree[ls]=static_cast<int>((tree[ls]+static_cast<long long>(mid-l+1)*add)%mod);
    tree[rs]=static_cast<int>((tree[rs]+static_cast<long long>(r-mid)*add)%mod);

    lazy[now]=0;
}

void updata(int l,int r,int now,int left,int right,int k)
{
    int mid=(l+r)>>1;
    if(left<=l&&r<=right)
    {
        tree[now]+=(r-l+1)*k;
        lazy[now]+=k;

        tree[now]%=mod;
        lazy[now]%=mod;

        return ;
    }
    if(lazy[now])
        pushdown(l,r,now);
    if(left<=mid) updata(lson,left,right,k);
    if(right>mid) updata(rson,left,right,k);

    tree[now]=tree[ls]+tree[rs];
    tree[now]%=mod;
}

// Adds k to every node on the tree path between x and y.
// While x and y are on different heavy chains, the endpoint whose chain head is
// deeper is lifted: its chain segment [id[head], id[endpoint]] is updated, and
// the endpoint jumps to the head's parent. (On a depth tie, x is lifted.)
// Once both endpoints share a chain, the contiguous segment between them is
// updated.
void cmd1(int x,int y,int k)
{
    while(N[x].top!=N[y].top)
    {
        if(N[N[x].top].d<N[N[y].top].d)
            std::swap(x,y);
        const int chainTop=N[x].top;
        updata(rt,id[chainTop],id[x],k);
        x=N[chainTop].f;
    }
    if(id[x]>id[y])
        std::swap(x,y);
    updata(rt,id[x],id[y],k);
}

int query(int l,int r,int now,int left,int right)
{
    int mid=(l+r)>>1;
    if(left<=l&&r<=right)
    {
        return tree[now];
    }
    if(lazy[now])
        pushdown(l,r,now);
    int ans=0;
    if(left<=mid) ans+=query(lson,left,right);
    if(right>mid) ans+=query(rson,left,right);

    tree[now]=tree[ls]+tree[rs];
    tree[now]%=mod;

    ans%=mod;
    return ans;
}

// Returns the sum of node weights on the tree path between x and y, modulo
// `mod`. It lifts the endpoint with the deeper chain head (x on ties),
// summing one chain segment per step, until x and y share a chain. Then it
// adds the segment between them.
// The accumulator is long long and reduced after every chain. In the original,
// several values below an int-sized mod were added in an int before any
// reduction, which overflows.
int cmd2(int x,int y)
{
    int fx=N[x].top,fy=N[y].top;
    long long ans=0;

    while(fx!=fy)
    {
        if(N[fx].d>=N[fy].d)
        {
            ans=(ans+query(rt,id[fx],id[x]))%mod;
            x=N[fx].f,fx=N[x].top;
        }
        else
        {
            ans=(ans+query(rt,id[fy],id[y]))%mod;
            y=N[fy].f,fy=N[y].top;
        }
    }
    // Same chain now: the segment between x and y is contiguous in DFS order.
    if(id[x]<=id[y])
        ans+=query(rt,id[x],id[y]);
    else
        ans+=query(rt,id[y],id[x]);

    return static_cast<int>(ans%mod);
}

// Adds k to every node in the subtree of x. dfs2 numbers a subtree
// contiguously, so it is exactly the DFS index range [id[x], N[x].maxid].
void cmd3(int x,int k)
{
    const int first=id[x];
    const int last=N[x].maxid;
    updata(rt,first,last,k);
}

// Reads n, m, root r and mod, then n weights, n-1 undirected edges and m
// operations:
//   1 x y z : add z to every node on the path x-y
//   2 x y   : print the sum over the path x-y
//   3 x z   : add z to every node in the subtree of x
//   4 x     : print the sum over the subtree of x
// Any other operation code reads nothing further and is ignored.
// The segment tree takes its results modulo `mod`.
int main()
{
    scanf("%d%d%d%d",&n,&m,&r,&mod);
    for(int i=1;i<=n;i++)
        scanf("%d",&N[i].w);
    for(int i=1;i<n;i++)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        // Undirected tree: store both directions.
        creatEdge(a,b);
        creatEdge(b,a);
    }
    dfs1(r,r,1);
    dfs2(r,0,r);
    built(rt);
    while(m--)
    {
        int cmd,x,y,z;
        scanf("%d%d",&cmd,&x);
        switch(cmd)
        {
        case 1:
            scanf("%d%d",&y,&z);
            cmd1(x,y,z);
            break;
        case 2:
            scanf("%d",&y);
            printf("%d\n",cmd2(x,y));
            break;
        case 3:
            scanf("%d",&z);
            cmd3(x,z);
            break;
        case 4:
            printf("%d\n",query(rt,id[x],N[x].maxid));
            break;
        default:
            break;
        }
    }
    return 0;
}

 

// Originally posted 2019-06-13 19:31 by Minun (blog stats: 165 views, 0 comments).