[HNOI2014] 世界树

Description

给定一棵 \(n\) 个点的无根树,有 \(q\) 个询问,每次询问将一些点设置为关键点,要求求出每个关键点管辖了多少个节点。\(a\) 管辖 \(b\) 当且仅当 \(a\) 是距离 \(b\) 最近的关键点或是最近的关键点中编号最小的那一个。\(n\leq 3\cdot 10^5,q\leq 3\cdot 10^5,\sum len\leq 3\cdot 10^5\)

Sol

看见这个 \(\sum len\) 那就肯定是要在虚树上面乱搞了。

首先建出虚树,然后可以通过换根\(\mathrm{dfs}\)求出虚树上每个点被哪个关键点管辖。

然后就要统计不在虚树上的那些点对答案的贡献了。

不在虚树上的点分为两部分,一是 在虚树上的某条边中,二是 在虚树上的某个点的某个子树中(这个子树是原树里的子树)

那就可以愉快的统计了。

记录 \(sze[i]\) 表示原树中子树 \(i\) 的大小,\(siz[i]\) 表示 \(\sum\limits_{(i,j)\text{not in tree}}sze[j]\),即所有没有在虚树中出现的点 \(i\) 的儿子的子树和。可以注意到,管辖 \(siz[i]\) 这些点的关键点一定管辖点 \(i\)

然后就要求,在虚树上某条边中的未出现的点的贡献了。

对于一条虚树上的边 \((x,y)\),首先找到 \(x\) 沿着这条链的第一个孩子 \(s\) ,如果管辖两端点的关键点一样,那直接将 \(sze[s]-sze[y]\) 加进该关键点的答案内,表示这条链上所有的点以及所有点的子树都会被相同的关键点管辖。否则,就要在链上二分出来一个 \(mid\),表示 \(mid\) 以及向下的点都被管辖 \(y\) 的关键点覆盖,\(mid\) 向上的点都被管辖 \(x\) 的关键点覆盖。两个答案分别加上 \(sze[s]-sze[mid],sze[mid]-sze[y]\) 即可。

Code

先记录一下如何建虚树,就是维护一个单调栈,栈内元素深度单调递增,也就是维护了虚树的一条链。然后每次新加入点的时候判断一下,如果这条链走到了头就得往外弹栈,具体看代码,比较好理解。

void ins(int x){
    if(top<=1) return stk[++top]=x,void();
    int LCA=lca(stk[top],x);
    if(LCA==stk[top]) return stk[++top]=x,void();
    while(top>1 and dfn[stk[top-1]]>=dfn[LCA])
        add(stk[top-1],stk[top]),top--;
    if(LCA!=stk[top]) add(LCA,stk[top]),stk[top]=LCA;
    stk[++top]=x;
}

然后是这道题的代码。

#pragma GCC optimize(2)
#include<bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long ll;
typedef pair<int,int> pii;
const int N=3e5+5;

pii bel[N];
int is[N],stk[N],top,tot,f[N][20];
int lg[N],d[N],sze[N],siz[N],ans[N];
int n,m,a[N],cnt,dfn[N],head[N],b[N];

struct Edge{
    int to,nxt;
}edge[N<<1];

bool cmp(int x,int y){
    return dfn[x]<dfn[y];
}

void add(int x,int y){
    edge[++cnt].to=y;
    edge[cnt].nxt=head[x];
    head[x]=cnt;
}

void dfs(int now){
    sze[now]=1; dfn[now]=++tot;
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        if(sze[to]) continue;
        f[to][0]=now;d[to]=d[now]+1;
        for(int j=1;j<=lg[d[to]];j++)
            f[to][j]=f[f[to][j-1]][j-1];
        dfs(to); sze[now]+=sze[to];
    }
}

int lca(int x,int y){
    if(d[x]<d[y]) swap(x,y);
    for(int j=lg[d[x]];~j;j--)
        if(d[f[x][j]]>=d[y]) x=f[x][j];
    if(x==y) return x;
    for(int j=lg[d[x]];~j;j--)
        if(f[x][j]!=f[y][j]) x=f[x][j],y=f[y][j];
    return f[x][0];
}

void ins(int x){
    if(top<=1) return stk[++top]=x,void();
    int LCA=lca(stk[top],x);
    if(LCA==stk[top]) return stk[++top]=x,void();
    while(top>1 and dfn[stk[top-1]]>=dfn[LCA])
        add(stk[top-1],stk[top]),top--;
    if(LCA!=stk[top]) add(LCA,stk[top]),stk[top]=LCA;
    stk[++top]=x;
}

void dfs1(int now){
    if(is[now]==m) bel[now]=pii(0,now);
    else bel[now]=pii(1e9,0);
    siz[now]=sze[now]-1;
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        dfs1(to);
        bel[now]=min(bel[now],pii(bel[to].first+d[to]-d[now],bel[to].second));
    }
}

void dfs2(int now,int fa){
    if(now!=1) 
        bel[now]=min(bel[now],pii(bel[fa].first+d[now]-d[fa],bel[fa].second));
    ans[bel[now].second]++;
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        dfs2(to,now);
    }
}

int dis(int x,int y){
    return d[x]+d[y]-2*d[lca(x,y)];
}

void dfs3(int now){
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        dfs3(to);
        int son=to,mid=to;
        for(int j=lg[d[son]];~j;j--)
            if(d[f[son][j]]>d[now]) son=f[son][j];
        siz[now]-=sze[son];
        if(bel[now].second==bel[to].second){ans[bel[now].second]+=sze[son]-sze[to];continue;}
        for(int j=lg[d[mid]];~j;j--){
            int p=f[mid][j];
            if(d[p]<=d[now]) continue;
            if(pii(dis(p,bel[to].second),bel[to].second)<pii(dis(p,bel[now].second),bel[now].second))
                mid=p;
        }
        ans[bel[now].second]+=sze[son]-sze[mid];
        ans[bel[to].second]+=sze[mid]-sze[to];
    }
    ans[bel[now].second]+=siz[now];
    head[now]=0;
}

void work(int len=0){
    scanf("%d",&len); cnt=0;
    for(int i=1;i<=len;i++)
        scanf("%d",&a[i]),b[i]=a[i],is[a[i]]=m,ans[a[i]]=0;
    std::sort(a+1,a+1+len,cmp);
    top=0; if(is[1]!=m) stk[++top]=1;
    for(int i=1;i<=len;i++) ins(a[i]);
    while(top>1) add(stk[top-1],stk[top]),top--;
    dfs1(1),dfs2(1,0),dfs3(1);
    for(int i=1;i<=len;i++) printf("%d ",ans[b[i]]);puts("");
}

signed main(){
    scanf("%d",&n);
    for(int x,y,i=1;i<n;i++)
        scanf("%d%d",&x,&y),add(x,y),add(y,x);
    for(int i=2;i<=n;i++) lg[i]=lg[i-1]+((1<<lg[i-1]+1)==i);
    d[1]=1; dfs(1); cnt=0; memset(head,0,sizeof head);
    for(scanf("%d",&m);m;work(),m--);
    return 0;
}

posted @ 2019-02-15 14:42  YoungNeal  阅读(209)  评论(0编辑  收藏  举报