CF1172E Nauuo and ODT LCT

自己独立想出来的,超级开心 

一开始想的是对于每一个点分别算这个点对答案的贡献. 

但是呢,我们发现由于每一条路径的贡献是该路径颜色种类数,而每个颜色可能出现多次,所以这样就特别不好算贡献. 

那么,还是上面那句话,由于算的是颜色种类,所以我们可以对每一个颜色种类单独算贡献. 

即不以点为单位去算,而是以颜色种类为单位去算.    

假设要算颜色为 $r$ 的贡献,那么必须保证每一个路径至少有一个端点在颜色 $r$ 构成的连通块中. 

这句话等同于不能出现两个端点都在非 $r$ 连通块的路径,即 $n^2-\sum_{col[i]\neq r}size[i]^2$   

对于每一个颜色都这么算就好了 ~ 

具体的话需要离线+撤销+LCT维护子树信息(就是那个平方和) 

然后还要用到那个点权转边权,每次只删除和父亲连边的那个套路 ~    

code: 

#include <cstdio> 
#include <vector>   
#include <cstring>  
#include <algorithm>  
#define N 600003  
#define LL long long  
#define lson t[x].ch[0] 
#define rson t[x].ch[1]         
#define setIO(s) freopen(s".in","r",stdin) ,freopen(s".out","w",stdout)   
using namespace std; 
LL ans,re[N];  
int edges;  
int fa[N],hd[N],to[N<<1],nex[N<<1],val[N],size[N],col[N],is[N];  
void add(int u,int v) 
{
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;   
}       
struct data 
{
    int u,v,tim;   
    data(int u=0,int v=0,int tim=0):u(u),v(v),tim(tim){}  
};   
vector<data>G[N];  
struct node  
{ 
    LL sqr;    
    int ch[2],rev,f,siz,son;   
}t[N];         
int get(int x) 
{
    return t[t[x].f].ch[1]==x; 
}
int isrt(int x) 
{
    return !(t[t[x].f].ch[0]==x||t[t[x].f].ch[1]==x);  
}
void pushup(int x) 
{
    t[x].siz=t[lson].siz+t[rson].siz+t[x].son+1;         
}
void rotate(int x) 
{
    int old=t[x].f,fold=t[old].f,which=get(x); 
    if(!isrt(old))   t[fold].ch[t[fold].ch[1]==old]=x;        
    t[old].ch[which]=t[x].ch[which^1],t[t[old].ch[which]].f=old; 
    t[x].ch[which^1]=old,t[old].f=x,t[x].f=fold; 
    pushup(old),pushup(x); 
}   
void splay(int x) 
{
    int u=x,fa; 
    for(;!isrt(u);u=t[u].f);            
    for(u=t[u].f;(fa=t[x].f)!=u;rotate(x))    
    {
        if(t[fa].f!=u)     
        { 
            rotate(get(fa)==get(x)?fa:x);  
        }
    }
}   
void Access(int x) 
{
    for(int y=0;x;y=x,x=t[x].f)    
    {
        splay(x); 
        if(rson) 
        {
            t[x].son+=t[rson].siz;            
            t[x].sqr+=(LL)t[rson].siz*t[rson].siz;       
        }
        if(y) 
        {
            t[x].son-=t[y].siz; 
            t[x].sqr-=(LL)t[y].siz*t[y].siz;    
        }
        rson=y;  
        pushup(x);  
    }
}   
void link(int x,int y) 
{                
    Access(x),splay(x);   
    t[y].f=x;   
    t[x].son+=t[y].siz;   
    t[x].sqr+=(LL)t[y].siz*t[y].siz;                  
    pushup(x);   
}
// x 与 x 的父亲 
void cut(int x) 
{ 
    Access(x),splay(x);   
    if(lson)       
    {
        t[lson].f=0;   
        lson=0;           
        pushup(x);  
    }   
}
int findroot(int x) 
{   
    Access(x),splay(x);    
    while(lson)  x=lson;    
    return x;  
}         
void turn_0(int x) 
{             
    Access(x),splay(x);  
    int now=t[x].siz;      
    ans-=t[x].sqr;                     
    if(fa[x])  link(fa[x],x);       
    int p=findroot(x);                  
    splay(p);       
    is[x]=0;                   
    p=is[p]?t[p].ch[1]:p;          
    int ori=t[p].siz;         
    ans-=(LL)(ori-now)*(ori-now);   
    ans+=(LL)ori*ori;         
}
void turn_1(int x) 
{                      
    int p=findroot(x);          
    splay(p);                              
    p=is[p]?t[p].ch[1]:p;                      
    int ori=t[p].siz;    
    ans-=(LL)ori*ori;    
    cut(x);    
    int now=t[x].siz;                   
    ans+=(LL)(ori-now)*(ori-now);        
    ans+=(LL)t[x].sqr;            
    is[x]=1; 
}
void dfs(int u,int ff) 
{    
    size[u]=1;
    fa[u]=t[u].f=ff;        
    for(int i=hd[u];i;i=nex[i]) 
    {
        int v=to[i]; 
        if(v==ff)    continue;   
        dfs(v,u);    
        size[u]+=size[v];  
        t[u].son+=size[v];  
        t[u].sqr+=(LL)size[v]*size[v];  
    }   
    pushup(u);  
}
int main() 
{ 
    // setIO("input"); 
    int i,j,n,m,mx=0;   
    scanf("%d%d",&n,&m);                  
    for(i=1;i<=n;++i)   
    {
        scanf("%d",&col[i]); 
        val[i]=col[i]; 
        mx=max(mx,col[i]);    
        G[val[i]].push_back(data(i,val[i],0));        
    }
    for(i=1;i<n;++i)   
    {
        int u,v;  
        scanf("%d%d",&u,&v);    
        add(u,v),add(v,u);  
    }   
    dfs(1,0);    
    for(i=1;i<=m;++i) 
    {
        int u,v; 
        scanf("%d%d",&u,&v);      
        mx=max(mx,v);  
        if(val[u]==v)   continue;   
        G[val[u]].push_back(data(u,v,i));       //  val[u]->v  
        G[v].push_back(data(u,v,i));            //  ?->v       
        val[u]=v;  
    }               
    for(i=1;i<=mx;++i) 
    { 
        ans=(LL)n*n; 
        LL pre=0;               
        for(j=0;j<G[i].size();++j) 
        {                
            if(G[i][j].v==i)             // 别的变成 i   (0->1)
            {           
                turn_1(G[i][j].u);            
            }
            else                         // i 变成别的   (1->0)
            {  
                turn_0(G[i][j].u);      
            }                                    
            re[G[i][j].tim]-=pre;   
            re[G[i][j].tim]+=(LL)n*n-ans;   
            pre=(LL)n*n-ans;   
        }                
        for(j=G[i].size()-1;j>=0;--j) 
        {
            if(G[i][j].v==i)             // 别的变成 i   (0->1)
            {           
                turn_0(G[i][j].u);            
            }
            else                         // i 变成别的   (1->0)
            {  
                turn_1(G[i][j].u);      
            }   
        }        
    }    
    printf("%lld\n",re[0]);   
    for(i=1;i<=m;++i)  re[i]+=re[i-1], printf("%lld\n",re[i]); 
    return 0; 
}   

  

posted @ 2019-12-16 11:31  EM-LGH  阅读(149)  评论(0编辑  收藏  举报