Luogu P3233 [HNOI2014]世界树

虚树 + DP

想了很久很久还调了半天

就算有倍增数组我也要树剖求lca

先建出虚树,然后虚树上的每个节点维护到自己最近的关键点的距离 \(dis[u]\) 和所属的关键点 \(be[u]\)

dp1(int u) 从底往上去更新, dp2(int u) 从上向下去更新。

计算答案时,我们发现若虚树上一条边 \((u,v)\)(即原图中的一条链 \((u,\cdots,v)\) )满足 \(be[u]\neq be[v]\) ,那么 \((u,\cdots,v)\) 中一定被分为了不相交的两部分,即

\(be[i]=be[u],i\in(u,\cdots,x)\\be[j]=be[v],j\in (y,\cdots,v)\\(u,\cdots,x)+(y,\cdots,v)=(u,\cdots,v)\)

为了快速获得分界点,我们在原树上倍增去找即可;记得排除同一棵子树的贡献,即在父节点时要减掉子树的 \(sz\)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define R register int
using namespace std;
namespace Luitaryi { 
inline int g() { R x=0,f=1;
  register char s; while(!isdigit(s=getchar())) f=s=='-'?-1:f;
  do x=x*10+(s^48); while(isdigit(s=getchar())); return x*f;
} const int N=300010,B=18,Inf=0x3f3f3f3f;
int n,m,k;
int vr[N<<1],nxt[N<<1],fir[N],w[N<<1],cnt,num;
int pre[N],d[N],sz[N],son[N],dfn[N],top[N],fa[N][B+1],h[N],stk[N],siz[N],ans[N];
int dis[N],be[N],mem[N]; bool vis[N];
inline void add(int u,int v,int ww) 
  {vr[++cnt]=v,nxt[cnt]=fir[u],fir[u]=cnt,w[cnt]=ww;}
inline void dfs(int u) {
  dfn[u]=++num,sz[u]=1;
  for(R t=1;t<=B;++t) 
    fa[u][t]=fa[fa[u][t-1]][t-1];
  for(R i=fir[u];i;i=nxt[i]) {
    R v=vr[i];
    if(dfn[v]) continue;
    pre[v]=u,d[v]=d[u]+1;
    fa[v][0]=u;
    dfs(v);
    sz[u]+=sz[v];
    if(sz[son[u]]<sz[v]) son[u]=v;
  }
}
inline void dfs2(int u,int tp) {
  top[u]=tp;
  if(son[u]) dfs2(son[u],tp);
  for(R i=fir[u];i;i=nxt[i]) {
    R v=vr[i];
    if(!top[v]) dfs2(v,v);
  }
}
inline bool cmp(const int& a,const int& b) 
  {return dfn[a]<dfn[b];}
inline int lca(int u,int v) {
  while(top[u]!=top[v]) {
    if(d[top[u]]<d[top[v]]) swap(u,v);
    u=pre[top[u]];
  } return d[u]<d[v]?u:v;
}
inline int dist(int u,int v) {return d[u]+d[v]-2*d[lca(u,v)];}
inline int jmp(int u,int dep) {
  for(R t=B;~t;--t) if(d[fa[u][t]]>=dep) u=fa[u][t];
  return u;
}
inline void dp1(int u) {
  siz[u]=sz[u];
  if(vis[u]) dis[u]=0,be[u]=u;
  else dis[u]=Inf,be[u]=0;
  for(R i=fir[u];i;i=nxt[i]) {
    R v=vr[i];
    dp1(v);
    if(dis[u]>dis[v]+w[i]||(dis[u]==dis[v]+w[i]&&be[u]>be[v]))
      dis[u]=dis[v]+w[i],be[u]=be[v];
  }
} 
inline void dp2(int u) {
  for(R i=fir[u];i;i=nxt[i]) {
    R v=vr[i];
    if(dis[v]>dis[u]+w[i]||(dis[v]==dis[u]+w[i]&&be[u]<be[v]))
      dis[v]=dis[u]+w[i],be[v]=be[u];
    dp2(v);
    if(be[u]==be[v]) siz[u]-=sz[v];
    else {
      R t=dis[u]+(d[v]-d[u])-dis[v];
      R tmp=jmp(v,d[v]-(t-1)/2);
      if(t&&t%2==0&&be[v]<be[u]) tmp=pre[tmp];
      siz[v]+=sz[tmp]-sz[v];
      siz[u]-=sz[tmp];
    }
    ans[be[v]]+=siz[v];
  } 
}
inline void main() {
  n=g();
  for(R i=1,u,v;i<n;++i) 
    u=g(),v=g(),add(u,v,0),add(v,u,0);
  d[1]=1,dfs(1),dfs2(1,1);
  memset(fir,0,sizeof fir);
  m=g();
  while(m--) {
    k=g();
    for(R i=1;i<=k;++i) vis[mem[i]=h[i]=g()]=true;
    sort(h+1,h+k+1,cmp);
    fir[1]=cnt=0; R top;
    stk[top=1]=1;
    for(R i=1+(h[1]==1),l;i<=k;++i) {
      l=lca(h[i],stk[top]);
      if(l!=stk[top]) {
        while(dfn[l]<dfn[stk[top-1]]) 
          add(stk[top-1],stk[top],dist(stk[top],stk[top-1])),--top;
        if(dfn[l]>dfn[stk[top-1]])
          fir[l]=0,
          add(l,stk[top],dist(stk[top],l)),stk[top]=l;
        else 
          add(l,stk[top],dist(stk[top],l)),--top;
      }
      fir[h[i]]=0,stk[++top]=h[i];
    }
    for(R i=1;i<top;++i) 
      add(stk[i],stk[i+1],dist(stk[i+1],stk[i]));
    dp1(1),dp2(1); ans[be[1]]+=siz[1];
    for(R i=1;i<=k;++i) printf("%d ",ans[mem[i]]),ans[mem[i]]=0; puts("");
    for(R i=1;i<=k;++i) vis[h[i]]=0;
  }
}
} signed main() {Luitaryi::main(); return 0;}

2020.01.17

posted @ 2020-01-17 20:39  LuitaryiJack  阅读(172)  评论(0编辑  收藏  举报