题解:

LCT（Link-Cut Tree）+ 链上修改（路径加、路径乘、路径求和，以及删边再连边）

每次链上修改时，先用 split 把这条路径单独提成一棵 splay，再在它的根上打 lazy 标记（乘法标记和加法标记），等到 pushdown 时再下传给子节点。

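如果有了乘法，加法的 lazy 标记也要相应地随之变化：先做 $v \mapsto v\cdot m_1 + a_1$、再做 $v \mapsto v\cdot m_2 + a_2$，等价于 $v \mapsto v\cdot (m_1 m_2) + (a_1 m_2 + a_2)$。所以乘上 $m$ 时，原有的加法标记也要乘上 $m$（见 cal 函数）。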

代码:

#pragma GCC optimize(2)
#include<bits/stdc++.h>
// M: modulus applied to every stored value, tag and aggregate.
// N: capacity of the per-node arrays (assumes node ids 1..n with n < N).
const int M=51061,N=100010;
// NOTE: despite the name, ll is plain `unsigned int`, not long long.
typedef unsigned ll;
// Import only what the file uses. A blanket `using namespace std;` brings
// std::size (C++17) into scope, which makes every unqualified use of the
// global array `size` ambiguous and the file fails to compile.
using std::swap;
// Read the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it. Returns 0 if EOF is reached first
// (the original looped forever there, because EOF never ends the skip loop).
int read()
{
    int ch=getchar();              // int, not char, so EOF stays distinguishable
    while (ch!=EOF&&(ch<'0'||ch>'9'))ch=getchar();
    int x=0;
    while (ch>='0'&&ch<='9')
     {
        x=x*10+(ch-'0');
        ch=getchar();
     }
    return x;
}
// n: number of nodes; m: number of operations (both read in main).
// c[x][0] / c[x][1]: left / right child of x in its splay tree (0 = none).
// fa[x]: splay parent of x, or its path-parent when x is a splay root (see isroot).
// size[x]: node count of x's splay subtree, kept mod M; it is only consumed by
//          cal's modular arithmetic, so the reduction is harmless.
// rev[x]: pending flag meaning x's children still have to be swapped (see pushdown).
// NOTE(review): top, cnt and q are declared but never used in this file.
int n,m,top,cnt,c[N][2],fa[N],size[N],q[N],rev[N];
// val[x]: x's own value (mod M); sum[x]: sum of val over x's splay subtree (mod M).
// mt[x] / at[x]: pending multiply / add tag still owed to x's children; x's own
//                val and sum already include it (see cal and pushdown).
ll sum[N],val[N],at[N],mt[N];
void cal(int x,int m,int a)
{
    if (!x)return;
    val[x]=(val[x]*m+a)%M;
    sum[x]=(sum[x]*m+a*size[x])%M;
    at[x]=(at[x]*m+a)%M;
    mt[x]=(mt[x]*m)%M;
}
// Nonzero when x is the root of its splay tree, i.e. its parent does not
// list x as a child (fa[x] is then only a path-parent pointer).
int isroot(int x)
{
    const int p=fa[x];
    return !(c[p][0]==x||c[p][1]==x);
}
// Recompute x's subtree aggregates (value sum and node count, both mod M)
// from its two children and its own value. Index 0 acts as an empty child.
void update(int x)
{
    const int *kid=c[x];
    sum[x]=(val[x]+sum[kid[0]]+sum[kid[1]])%M;
    size[x]=(1+size[kid[0]]+size[kid[1]])%M;
}
void pushdown(int x)
{
    int l=c[x][0],r=c[x][1];
    if (rev[x])
     {
        rev[x]^=1;rev[l]^=1;rev[r]^=1;
        swap(c[x][0],c[x][1]);
     }
    int m=mt[x],a=at[x];
    mt[x]=1;at[x]=0;
    if(m!=1||a!=0)cal(l,m,a),cal(r,m,a);
}
// Single splay rotation: lift x above its parent y while preserving the
// in-order sequence. splay() pushes tags down (via down) before calling this.
void rotate(int x)
{
    int y=fa[x],z=fa[y],l,r;
    // l: which side of y holds x (0 = left, 1 = right); r: the opposite side.
    l=(c[y][1]==x);r=l^1;
    // Rewire z's child slot only if y was a real splay child. If y was a splay
    // root, z is just a path-parent and holds no child pointer to y.
    if (!isroot(y))c[z][c[z][1]==y]=x;
    // fa[c[x][r]] may write fa[0] when that subtree is empty; harmless.
    fa[x]=z;fa[y]=x;fa[c[x][r]]=y;
    // x's inner subtree moves under y, and y becomes x's r-side child.
    c[y][l]=c[x][r];c[x][r]=y;
    // y is now below x, so recompute y before x.
    update(y);update(x);
}
void down(int x){if (!isroot(x))down(fa[x]);pushdown(x);}
// Rotate x up to the root of its splay tree.
void splay(int x)
{
    // Resolve pending reverse / multiply-add tags from the splay root down to x first.
    down(x);
    for (int y=fa[x];!isroot(x);rotate(x),y=fa[x])
     // Double rotation. zig-zig (x and y on the same side of their parents):
     // rotate y first. zig-zag: rotate x. The loop step's rotate(x) finishes it.
     if (!isroot(y))rotate((c[y][0]==x)==(c[fa[y]][0]==y)?y:x);
}
// Make the path from the tree root to x a single preferred path (one splay
// tree). Walk up the path-parent pointers and, at each step, replace the
// node's right splay child with the splay tree just built below it.
void access(int x)
{
    int below=0;
    while (x)
     {
        splay(x);
        c[x][1]=below;
        update(x);
        below=x;
        x=fa[x];
     }
}
// Re-root x's tree at x: expose the root-to-x path, splay x to the top, then
// lazily reverse that path so x becomes its shallowest node.
void makeroot(int x)
{
    access(x);
    splay(x);
    rev[x]=!rev[x];
}
// Isolate the path between y and x. Afterwards x is the root of a splay tree
// that holds exactly that path, so sum[x] and cal(x, ...) cover the whole path.
void split(int x,int y)
{
    makeroot(y);
    access(x);
    splay(x);
}
// Add the edge (x, y): re-root x's tree at x, then hang it below y through a
// path-parent pointer (no splay child pointer is set).
void link(int x,int y)
{
    makeroot(x);
    fa[x]=y;
}
void cut(int x,int y){makeroot(x);access(y);splay(y);c[y][0]=fa[x]=0;}
int main()
{
    n=read();m=read();
    for (int i=1;i<=n;i++)val[i]=sum[i]=mt[i]=size[i]=1;
    for (int i=1;i<n;i++)
     {
        int u=read(),v=read();
        link(u,v);
     }
    char ch[5];
    while(m--)
     {
        scanf("%s",ch);
        int u=read(),v=read();
        if (ch[0]=='+')
         {
            int c=read();
            split(u,v);cal(u,1,c);
         }
        if (ch[0]=='-')
         {
            cut(u,v);
            u=read();v=read();link(u,v);
         }
        if (ch[0]=='*')
         {
            int c=read();
            split(u,v);cal(u,c,0);
         }
        if (ch[0]=='/')
         {
            split(u,v);
            printf("%d\n",sum[u]);
         }
     }
    return 0;
}

 

posted on 2017-12-11 18:09  宣毅鸣  阅读(134)  评论(0)  编辑  收藏  举报