bzoj2631: tree lct

这道题需要注意的就是标记的传递 规定先传乘法标记再传加法标记就好了

#include<cstdio>
#include<cstring>
#include<algorithm>
#define LL unsigned int
using namespace std;
const int M=100007,mod=51061; 
int read(){
    int ans=0,f=1,c=getchar();
    while(c<'0'||c>'9'){if(c=='-') f=-1; c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+(c-'0'); c=getchar();}
    return ans*f;
}
LL v[M],sum[M],tag[M],mt[M];
int c[M][2],fa[M],rev[M],size[M];
int n,m;
bool isrt(int x){return c[fa[x]][0]!=x&&c[fa[x]][1]!=x;}
void up(int x){
    int l=c[x][0],r=c[x][1];
    sum[x]=(sum[l]+sum[r]+v[x])%mod;
    size[x]=size[l]+size[r]+1;
}
void add(int x,int m,int w){
    if(!x) return ;
    v[x]=(v[x]*m+w)%mod;
    sum[x]=(sum[x]*m+w*size[x])%mod;
    tag[x]=(tag[x]*m+w)%mod;
    mt[x]=mt[x]*m%mod;
}
void down(int x){
    int l=c[x][0],r=c[x][1];
    if(rev[x]){
        rev[x]=0; rev[l]^=1; rev[r]^=1;
        swap(c[x][0],c[x][1]);
    }
    if(tag[x]||mt[x]!=1){
        add(l,mt[x],tag[x]);
        add(r,mt[x],tag[x]);
        tag[x]=0; mt[x]=1;
    }
}
void rotate(int x){
    int y=fa[x],z=fa[y],l=0,r=1;
    if(c[y][1]==x) l=1,r=0;
    if(!isrt(y)) c[z][c[z][1]==y]=x;
    fa[y]=x; fa[x]=z; fa[c[x][r]]=y;
    c[y][l]=c[x][r]; c[x][r]=y;
    up(y); up(x);
}
int st[M],top;
void splay(int x){
    st[++top]=x; for(int i=x;!isrt(i);i=fa[i]) st[++top]=fa[i];
    while(top) down(st[top--]);
    while(!isrt(x)){
        int y=fa[x],z=fa[y];
        if(!isrt(y)){
            if(c[z][0]==y^c[y][0]==x) rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
}
void acs(int x0){
    for(int x=x0,y=0;x;splay(x),c[x][1]=y,up(x),y=x,x=fa[x]);
    splay(x0);
}
void mrt(int x){acs(x); rev[x]^=1;}
void link(int x,int y){mrt(x); fa[x]=y;}
void split(int x,int y){mrt(x); acs(y);}
void cut(int x,int y){split(x,y); c[y][0]=fa[x]=0;}
int main()
{
    int x,y,w;
    char ch[5];
    n=read(); m=read();
    for(int i=1;i<=n;i++) size[i]=v[i]=mt[i]=sum[i]=1;
    for(int i=1;i<n;i++) x=read(),y=read(),link(x,y);
    for(int i=1;i<=m;i++){
        scanf("%s",ch); x=read(); y=read();
        if(ch[0]=='+') w=read(),split(y,x),add(x,1,w);
        else if(ch[0]=='-') cut(x,y),x=read(),y=read(),link(x,y);
        else if(ch[0]=='*') w=read(),split(y,x),add(x,w,0);
        else split(y,x),printf("%d\n",sum[x]);
    }
    return 0;
}
View Code

 

posted @ 2017-06-13 18:41  友人Aqwq  阅读(107)  评论(0编辑  收藏  举报