BZOJ 2815: [ZJOI2012]灾难 拓扑排序+倍增LCA

这种问题的转化方式挺巧妙的. 

Code: 

#include <bits/stdc++.h>               
#define N 100000    
#define M 1000000 
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std;   
queue<int>q;   
vector<int>G,V[N];            
int n,root,edges; 
int fa[20][N],hd[N],to[M],nex[M],deg[N],sum[N],dep[N];                
void add(int u,int v) 
{ 
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v,++deg[v];     
}   
void toposort() 
{     
    for(int i=1;i<=n+1;++i) if(deg[i]==0) q.push(i), G.push_back(i);        
    for(;!q.empty();) 
    {
        int u=q.front();q.pop();   
        for(int i=hd[u];i;i=nex[i])
        {
            int v=to[i]; 
            --deg[v]; 
            if(deg[v]==0) q.push(v), G.push_back(v);   
        }
    }
}
int LCA(int x,int y) 
{
    if(dep[x]>dep[y]) swap(x,y);      // dep of y is greater than dep of x        
    if(dep[x]!=dep[y]) 
    {
        for(int i=18;i>=0;--i) if(dep[fa[i][y]]>=dep[x]) y=fa[i][y];     
    }   
    if(x==y) return x;   
    for(int i=18;i>=0;--i) 
    {
        if(fa[i][x]!=fa[i][y]) 
        {
            x=fa[i][x],y=fa[i][y];   
        } 
    }   
    return fa[0][x];     
}
void dfs(int u) 
{
    sum[u]=1;  
    for(int i=0;i<V[u].size();++i) 
    {
        int v=V[u][i];    
        dfs(v);   
        sum[u]+=sum[v];   
    }
}
int main() 
{  
    int i,j; 
    // setIO("input");   
    scanf("%d",&n);                              
    root=n+1;     
    for(i=1;i<=n;++i) 
    {
        int t; 
        scanf("%d",&t);
        if(!t) add(i,root);   
        else for(;t!=0;scanf("%d",&t)) add(i,t);      
    }
    toposort();      
    dep[n+1]=1;                 
    for(i=G.size()-2;i>=0;--i) 
    {
        int u=G[i]; 
        int lca=to[hd[u]];       
        for(j=nex[hd[u]];j;j=nex[j]) lca=LCA(lca, to[j]);     
        fa[0][u]=lca;
        dep[u]=dep[lca]+1;                    
        V[lca].push_back(u);       
        for(j=1;j<=18;++j) fa[j][u]=fa[j-1][fa[j-1][u]];    
    }      
    dfs(root);   
    for(i=1;i<=n;++i) printf("%d\n",sum[i]-1);     
    return 0; 
}    

  

posted @ 2019-09-20 18:58  EM-LGH  阅读(162)  评论(0编辑  收藏  举报