BZOJ 3159: 决战 LCT+Splay

挺厉害的一道大数据结构题.    

由于 LCT 是维护树的形态的,所以说不支持翻转操作.  

而在维护序列时 splay 是支持区间翻转的.  

所以,我们对于 LCT 中每一个重链都维护一个 splay(这个不同于 LCT 中的 splay)     

由于重链是一个序列,所以是支持序列的区间翻转的.   

那么我们的翻转,链加和,链求和的操作就都在这个重链对应的 splay 上进行.   

然后这里一定要注意:我们在 LCT 中维护 LCT 中每个点对应到 splay 上的编号,只有 LCT 中的 splay 的根节点对应的是正确的编号.      

#include <cstdio>
#include <string>
#include <vector>
#include <cstring>
#include <algorithm> 
#define N 50007
#define ll long long
using namespace std;      
namespace IO
{
    void setIO(string s)
    {
        string in=s+".in";
        string out=s+".out";
        freopen(in.c_str(),"r",stdin);
        freopen(out.c_str(),"w",stdout);
    }
};
namespace Splay
{       
    #define lson s[x].ch[0]
    #define rson s[x].ch[1]   
    struct node
    {  
        int ch[2],f,rev,size;                  
        ll add,val,sum,Min,Max;  
    }s[N];   
    int sta[N];     
    int get(int x) { return s[s[x].f].ch[1]==x; } 
    void mark_rev(int x) { s[x].rev^=1,swap(lson,rson);} 
    void mark_add(int x,ll v)
    {
        s[x].add+=v;   
        s[x].sum+=1ll*s[x].size*v;   
        s[x].Min+=v,s[x].Max+=v,s[x].val+=v;
    }
    void pushup(int x)
    {    
        s[x].sum=s[x].Min=s[x].Max=s[x].val;
        s[x].size=s[lson].size+s[rson].size+1;     
        if(lson)
        {
            s[x].sum+=s[lson].sum;
            s[x].Min=min(s[x].Min,s[lson].Min);
            s[x].Max=max(s[x].Max,s[lson].Max);
        }
        if(rson)
        {
            s[x].sum+=s[rson].sum;
            s[x].Min=min(s[x].Min,s[rson].Min);
            s[x].Max=max(s[x].Max,s[rson].Max);
        }
    }  
    void pushdown(int x)
    {
        if(s[x].rev)
        {
            if(lson) mark_rev(lson);
            if(rson) mark_rev(rson); 
            s[x].rev=0;  
        }
        if(s[x].add)
        {
            if(lson) mark_add(lson,s[x].add);
            if(rson) mark_add(rson,s[x].add);  
            s[x].add=0; 
        }
 
    }
    void rotate(int x)
    { 
        int old=s[x].f,fold=s[old].f,which=get(x);     
        s[old].ch[which]=s[x].ch[which^1];
        if(s[old].ch[which]) s[s[old].ch[which]].f=old;
        s[x].ch[which^1]=old,s[old].f=x,s[x].f=fold;   
        if(fold) s[fold].ch[s[fold].ch[1]==old]=x;  
        pushup(old),pushup(x);  
    }
    void splay(int x)
    {
        int fa,v=0,tmp=x; 
        for(;tmp;tmp=s[tmp].f) sta[++v]=tmp;           
        for(;v;--v) pushdown(sta[v]);        
        for(;fa=s[x].f;rotate(x))  
            if(s[fa].f)    
                rotate(get(fa)==get(x)?fa:x);     
    }     
    int get_kth(int x,int kth)
    {                
        pushdown(x);    
        if(kth<=s[lson].size) return get_kth(lson,kth);  
        else if(s[lson].size+1==kth) return x;  
        else return get_kth(rson,kth-s[lson].size-1);
    }          
    int findrt(int x)
    {
        while(s[x].f) { x=s[x].f; }
        return x;     
    }                
    #undef lson
    #undef rson
};      
#define ls s[x].ch[0]
#define rs s[x].ch[1]
struct node
{  
    int ch[2],f,rev,size;     
}s[N];  
int sta[N],rt[N];              
int get(int x) { return s[s[x].f].ch[1]==x; }
int Isr(int x) { return s[s[x].f].ch[0]!=x&&s[s[x].f].ch[1]!=x; }  
void mark(int x) { swap(ls,rs), s[x].rev^=1; } 
void pushup(int x) { s[x].size=s[ls].size+s[rs].size+1; }           
void pushdown(int x)
{
    if(s[x].rev)
    {
        s[x].rev=0;
        if(ls) mark(ls);
        if(rs) mark(rs);
    }
}
void rotate(int x)
{
    int old=s[x].f,fold=s[old].f,which=get(x);                      
    if(!Isr(old)) s[fold].ch[s[fold].ch[1]==old]=x;  
    s[old].ch[which]=s[x].ch[which^1];    
    if(s[old].ch[which]) s[s[old].ch[which]].f=old;   
    s[x].ch[which^1]=old,s[old].f=x,s[x].f=fold; 
    pushup(old),pushup(x);  
}       
void splay(int x)
{
    int u=x,v=0,fa;
    for(sta[++v]=u;!Isr(u);u=s[u].f) sta[++v]=s[u].f;    
    rt[x]=rt[u];
    for(;v;--v) pushdown(sta[v]);    
    for(u=s[u].f;(fa=s[x].f)!=u;rotate(x))
        if(s[fa].f!=u)
            rotate(get(fa)==get(x)?fa:x); 
}
void Access(int x)
{   
    for(int y=0;x;y=x,x=s[x].f)
    {      
        splay(x);             
        if(rs)     // cut
        {                   
            rt[x]=Splay::get_kth(rt[x],s[ls].size+1);                         
            Splay::splay(rt[x]);          
            rt[rs]=Splay::s[rt[x]].ch[1];                  
            Splay::s[rt[rs]].f=0;
            Splay::s[rt[x]].ch[1]=0;               
            Splay::pushup(rt[x]);         
        } 
        if(y)        // link
        {           
            rt[x]=Splay::get_kth(rt[x],s[ls].size+1);  
            Splay::splay(rt[x]);     
            Splay::s[rt[x]].ch[1]=rt[y];    
            Splay::s[rt[y]].f=rt[x];   
            Splay::pushup(rt[x]);          
        }
        rs=y;
        pushup(x);
    }
}
void makeroot(int x) 
{ 
    Access(x),splay(x),mark(x),Splay::mark_rev(rt[x]); 
}
void split(int x,int y) 
{ 
    makeroot(x),Access(y),splay(y);
}
#undef ls
#undef rs 
int edges;
int hd[N],to[N<<1],nex[N<<1];   
void add(int u,int v)
{
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; 
}
void dfs(int u,int ff)
{
    s[u].f=ff;   
    rt[u]=u;      
    s[u].size=1;
    Splay::s[rt[u]].size=1;        
    for(int i=hd[u];i;i=nex[i])
    {
        int v=to[i];
        if(v==ff) continue; 
        dfs(v,u); 
    }
}
int main()
{
    // IO::setIO("input");
    int i,j,n,m,R;
    scanf("%d%d%d",&n,&m,&R);    
    for(i=1;i<n;++i)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y),add(y,x); 
    }  
    dfs(1,0); 
    for(i=1;i<=m;++i)
    {     
        char op[10];  
        int x,y,z; 
        scanf("%s",op+1);  
        if(op[3]=='c')
        { 
            scanf("%d%d%d",&x,&y,&z);  
            split(x,y);        
            Splay::mark_add(rt[y],(ll)z);  
        }
        if(op[3]=='m')
        {       
            scanf("%d%d",&x,&y);   
            split(x,y);     
            printf("%lld\n",Splay::s[rt[y]].sum); 
        }
        if(op[3]=='j')
        {   
            scanf("%d%d",&x,&y); 
            split(x,y);      
            printf("%lld\n",Splay::s[rt[y]].Max);  
        }
        if(op[3]=='n')
        {
            scanf("%d%d",&x,&y);  
            split(x,y);  
            printf("%lld\n",Splay::s[rt[y]].Min);
        }
        if(op[3]=='v')
        {   
            scanf("%d%d",&x,&y); 
            split(x,y); 
            Splay::mark_rev(rt[y]);  
        }
    }
    return 0;
}

  

posted @ 2020-01-09 16:18  EM-LGH  阅读(173)  评论(0编辑  收藏  举报