[HNOI2014]世界树

link

题目大意

有一颗树,且每次询问包含$k$个特殊节点,问每个特殊节点能占领几个节点

试题分析

到底是道德的沦丧,还是人性的扭曲。

其实发现数据范围$\sum_{i=1}^q m_i \leq 3\times 10^5$想到了什么---虚树,所以应该怎么用虚树优化。

先将虚树建出来,这是基本操作,点为$(u,v)$,则边权为$dis_{u,v}$

然后就$dfs$两遍,一遍求儿子对父亲的占领影响,然后第二遍求父亲对儿子的占领影响。记录$dis_i$为从$i$到特殊节点的最小距离,$bel_i$为$i$到哪一个特殊节点。

然后怎样去统计答案,我们发现若树边$(u,v)$,他们的$bel$相等,则就让父亲减去儿子的子树大小,就是不算重了。若不等,则我们可以计算出两端,两端值两个特殊节点的节点数。然后经过一系列的计算便可以倍增求出两个的中断值。然后在加减一下子树就行了。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<stack>
#include<algorithm>
#include<climits>
using namespace std;
inline int read(){
    int f=1,ans=0;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();}
    return f*ans;
}
const int N=300001;
struct node{
    int u,v,w,nex;
}x[N<<1];
int in[N],size[N],out[N],num,n,head[N],cnt,deep[N],fa[N][21];
void add(int u,int v,int w){
    x[cnt].u=u,x[cnt].v=v,x[cnt].w=w,x[cnt].nex=head[u],head[u]=cnt++;
}
int remove(int u,int Deep){
    for(int i=20;i>=0;i--) 
       if(deep[u]-(1<<i)>=Deep) u=fa[u][i];
     return u; 
}
void dfs(int f,int fath){
    in[f]=++num;deep[f]=deep[fath]+1,size[f]=1;fa[f][0]=fath;
    for(int i=1;(1<<i)<=deep[f];i++) fa[f][i]=fa[fa[f][i-1]][i-1];
    for(int i=head[f];i!=-1;i=x[i].nex){
        if(x[i].v==fath) continue;
        dfs(x[i].v,f);
        size[f]+=size[x[i].v];
    }out[f]=++num;
    return;
}
stack<int> s;
int que[N];
int lca(int u,int v){
    if(deep[u]<deep[v]) swap(u,v);
    for(int i=20;i>=0;i--)
        if(deep[u]-(1<<i)>=deep[v]) u=fa[u][i];
    if(u==v) return u;
    for(int i=20;i>=0;i--){
        if(fa[u][i]==fa[v][i]) continue;
        u=fa[u][i],v=fa[v][i];
    }return fa[u][0];
}
int Q,book[N],sta[N],st;
int Dis(int u,int v){return deep[u]+deep[v]-2*deep[lca(u,v)];}
bool cmp(int x,int y){
    int s1,s2;
    if(x>0) s1=in[x];else s1=out[-x];
    if(y>0) s2=in[y];else s2=out[-y];
    return s1<s2;
}
int eff[N];
void build(){
    memset(head,-1,sizeof(head)),cnt=0;
    sort(sta+1,sta+st+1,cmp);
    for(int i=1;i<st;i++){
        int Lca=lca(sta[i],sta[i+1]);
        if(!book[Lca]){book[Lca]=1;sta[++st]=Lca;}
    }
    int now=st;
    for(int i=1;i<=now;i++) sta[++st]=-sta[i];
    if(!book[1]) sta[++st]=1,sta[++st]=-1;
    sort(sta+1,sta+st+1,cmp);
    for(int i=1;i<=st;i++){
        if(sta[i]>0) s.push(sta[i]);
        else{
            int f=s.top();s.pop();
            if(f!=1){int fath=s.top();int dis=Dis(f,fath);add(f,fath,dis),add(fath,f,dis);}
        }
    }return;
}
int dis[N],bel[N],S[N];
void dfs1(int f,int fath){
    que[++que[0]]=f;
    if(eff[f]) dis[f]=0,bel[f]=f;
    else dis[f]=INT_MAX;
    S[f]=size[f];
    for(int i=head[f];i!=-1;i=x[i].nex){
        if(x[i].v==fath) continue;
        dfs1(x[i].v,f);
        if((dis[f]>dis[x[i].v]+x[i].w)||((dis[f]==dis[x[i].v]+x[i].w)&&(bel[f]>bel[x[i].v]))){
            dis[f]=dis[x[i].v]+x[i].w;
            bel[f]=bel[x[i].v];
        }
    }return;
}
int ans[N];
void dfs2(int f,int fath){
    for(int i=head[f];i!=-1;i=x[i].nex){
        if(x[i].v==fath) continue;
        if((dis[f]+x[i].w<dis[x[i].v])||((dis[f]+x[i].w==dis[x[i].v])&&bel[x[i].v]>bel[f])){
            dis[x[i].v]=x[i].w+dis[f];
            bel[x[i].v]=bel[f];
        }
        dfs2(x[i].v,f);
        if(bel[f]==bel[x[i].v]) S[f]-=size[x[i].v];
        else{
            int DIS=(dis[x[i].v]+deep[x[i].v])+(dis[f]-deep[f])-1;
            int k=DIS/2-dis[x[i].v];
            int tmp=remove(x[i].v,deep[x[i].v]-k);
            if((DIS%2)&&(bel[x[i].v]<bel[f])&&(k>=0)) tmp=fa[tmp][0];
            S[x[i].v]+=size[tmp]-size[x[i].v];
            S[f]-=size[tmp];
        }
        ans[bel[x[i].v]]+=S[x[i].v];
    }
    if(f==1) ans[bel[1]]+=S[1];
}
int xs[N];
int main(){
    memset(head,-1,sizeof(head));
    n=read();
    for(int i=1;i<n;i++){
        int u=read(),v=read();
        add(u,v,0),add(v,u,0);
    }
    dfs(1,0);
    Q=read();
    while(Q--){
        st=read();
        for(int i=1;i<=st;i++) sta[i]=read(),book[sta[i]]=1,eff[sta[i]]=1,xs[i]=sta[i];
        int ST=st;
        build();
        dfs1(1,0),dfs2(1,0);
        for(int i=1;i<=ST;i++) printf("%d ",ans[xs[i]]);
        for(int i=1;i<=que[0];i++) book[que[i]]=eff[que[i]]=dis[que[i]]=bel[que[i]]=S[que[i]]=ans[que[i]]=0;que[0]=0;
        printf("\n");
    }
    return 0;
}
View Code

 

posted @ 2018-12-10 22:16  siruiyang_sry  阅读(157)  评论(0编辑  收藏  举报