【洛谷P2664】 树上游戏 点分治

code:

#include <bits/stdc++.h>   
 
#define N 200009   
 
#define ll long long 
 
#define setIO(s) freopen(s".in","r",stdin)   
 
using namespace std; 
 
ll Sum[N];
 
int n,edges,root,sn;     
 
int val[N],hd[N],to[N<<1],nex[N<<1],size[N],mx[N],vis[N],A[N];   
 
inline void add(int u,int v) 
{
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;  
}     
 
void getroot(int u,int ff) 
{
    size[u]=1,mx[u]=0; 
 
    for(int i=hd[u];i;i=nex[i]) 
    {
        int v=to[i];   
 
        if(v==ff||vis[v])    continue;
 
        getroot(v,u);  
 
        size[u]+=size[v];   
 
        mx[u]=max(mx[u],size[v]);  
    }
 
    mx[u]=max(mx[u],sn-size[u]); 
 
    if(mx[u]<mx[root])   root=u;  
 
}        
  
int ou;     
 
ll tmp,tot,bu[N];           
 
map<int,ll>cn[N];     
 
map<int,ll>::iterator it;   
 
int dep[N],cnt[N],siz[N];   
 
void getnode(int top,int u,int ff,int cur) 
{               

    if(!cnt[val[u]])   
 
        ++cur;      
 
    ++cnt[val[u]];  
 
    Sum[top]+=(ll)cur;
 
    siz[u]=1;      
 
    for(int i=hd[u];i;i=nex[i])  
    {
        int v=to[i]; 
 
        if(v==ff||vis[v])     continue;    
 
        getnode(top,v,u,cur);  
 
        siz[u]+=siz[v];  
    }      
 
    --cnt[val[u]];                            
}            
void get_col(int top,int u,int ff) 
{                 
    if(!cnt[val[u]])     
 
        cn[top][val[u]]+=(ll)siz[u];             
 
    ++cnt[val[u]];     
 
    for(int i=hd[u];i;i=nex[i]) 
    {
 
        int v=to[i]; 
 
        if(v==ff||vis[v])   continue;      
 
        get_col(top,v,u); 
 
    } 
    --cnt[val[u]];        
}          
void calc_v(int u,int ff) 
{          
    ll tt=bu[val[u]];           
 
    tmp=tmp-bu[val[u]]+ou;            
 
    bu[val[u]]=ou;                    

    Sum[u]+=tmp;       

    for(int i=hd[u];i;i=nex[i]) 
    {
        int v=to[i]; 
 
        if(vis[v]||v==ff)    continue;          
 
        calc_v(v,u);                       
    }  

    tmp=tmp-bu[val[u]]+tt;    

    bu[val[u]]=tt;    
}
void clr(int u,int ff)
{
    cn[u].clear();    
    bu[val[u]]=0;         
    for(int i=hd[u];i;i=nex[i])
    {
        int v=to[i]; 
        if(v==ff||vis[v])   continue;   
        clr(v,u);     
    }
}
void calc(int u) 
{         
    tot=0;       
 
    getnode(u,u,0,0);       

    for(int i=hd[u];i;i=nex[i])           
    {
        int v=to[i];   
 
        if(vis[v])    continue;     
 
        // memset(cnt,0,sizeof(cnt));  
 
        get_col(v,v,u);    
 
        for(it=cn[v].begin();it!=cn[v].end();it++)    
        {
            tot+=it->second; 
 
            bu[it->first]+=it->second;                
        }
    }               

    for(int i=hd[u];i;i=nex[i])    
    {   
        int v=to[i];   
 
        if(vis[v])    continue;   
 
        tmp=tot;       
 
        ou=siz[u]-siz[v]; 
 
        for(it=cn[v].begin();it!=cn[v].end();it++)    
        {
            bu[it->first]-=it->second;   
 
            tmp-=it->second;            
        }

        ll tt=bu[val[u]];             
 
        tmp=tmp-bu[val[u]]+ou; 
 
        bu[val[u]]=ou;  
 
        calc_v(v,u);     
  
        bu[val[u]]=tt;   
 
        for(it=cn[v].begin();it!=cn[v].end();it++) bu[it->first]+=it->second;      
  
    }          
 
    clr(u,0);     
}
void dfs(int u) 
{
    calc(u);      
 
    vis[u]=1;     

    for(int i=hd[u];i;i=nex[i]) 
    {
        int v=to[i];   
 
        if(vis[v])    continue;     
 
        root=0,sn=size[v],getroot(v,u),dfs(root);   
    }
}
int main() 
{ 
    // setIO("input");   
 
    int i,j;        
 
    scanf("%d",&n);     
 
    for(i=1;i<=n;++i)     scanf("%d",&val[i]), A[i]=val[i]; 
 
    sort(A+1,A+1+n); 
 
    for(i=1;i<=n;++i)     val[i]=lower_bound(A+1,A+1+n,val[i])-A;     
 
    for(i=1;i<n;++i) 
    {
        int u,v; 
 
        scanf("%d%d",&u,&v),add(u,v),add(v,u);  
    }          
 
    sn=mx[0]=n,root=0,getroot(1,0),dfs(root);   
 
    for(i=1;i<=n;++i)     printf("%lld\n",Sum[i]);        
 
    return 0; 
}

  

posted @ 2019-12-06 20:14  EM-LGH  阅读(172)  评论(0编辑  收藏  举报