Luogu P3233 [HNOI2014]世界树
虚树 + DP
想了很久很久还调了半天
就算有倍增数组我也要树剖求lca
先建出虚树,然后虚树上的每个节点维护到自己最近的关键点的距离 \(dis[u]\) 和所属的关键点 \(be[u]\) 。
dp1(int u)
从底往上去更新, dp2(int u)
从上向下去更新。
计算答案时,我们发现若虚树上一条边 \((u,v)\)(即原图中的一条链 \((u,\cdots,v)\) )满足 \(be[u]\neq be[v]\) ,那么 \((u,\cdots,v)\) 中一定被分为了不相交的两部分,即
\(be[i]=be[u],i\in(u,\cdots,x)\\be[j]=be[v],j\in (y,\cdots,v)\\(u,\cdots,x)+(y,\cdots,v)=(u,\cdots,v)\)
为了快速获得分界点,我们在原树上倍增去找即可;记得排除同一棵子树的贡献,即在父节点时要减掉子树的 \(sz\) 。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define R register int
using namespace std;
namespace Luitaryi {
inline int g() { R x=0,f=1;
register char s; while(!isdigit(s=getchar())) f=s=='-'?-1:f;
do x=x*10+(s^48); while(isdigit(s=getchar())); return x*f;
} const int N=300010,B=18,Inf=0x3f3f3f3f;
int n,m,k;
int vr[N<<1],nxt[N<<1],fir[N],w[N<<1],cnt,num;
int pre[N],d[N],sz[N],son[N],dfn[N],top[N],fa[N][B+1],h[N],stk[N],siz[N],ans[N];
int dis[N],be[N],mem[N]; bool vis[N];
inline void add(int u,int v,int ww)
{vr[++cnt]=v,nxt[cnt]=fir[u],fir[u]=cnt,w[cnt]=ww;}
inline void dfs(int u) {
dfn[u]=++num,sz[u]=1;
for(R t=1;t<=B;++t)
fa[u][t]=fa[fa[u][t-1]][t-1];
for(R i=fir[u];i;i=nxt[i]) {
R v=vr[i];
if(dfn[v]) continue;
pre[v]=u,d[v]=d[u]+1;
fa[v][0]=u;
dfs(v);
sz[u]+=sz[v];
if(sz[son[u]]<sz[v]) son[u]=v;
}
}
inline void dfs2(int u,int tp) {
top[u]=tp;
if(son[u]) dfs2(son[u],tp);
for(R i=fir[u];i;i=nxt[i]) {
R v=vr[i];
if(!top[v]) dfs2(v,v);
}
}
inline bool cmp(const int& a,const int& b)
{return dfn[a]<dfn[b];}
inline int lca(int u,int v) {
while(top[u]!=top[v]) {
if(d[top[u]]<d[top[v]]) swap(u,v);
u=pre[top[u]];
} return d[u]<d[v]?u:v;
}
inline int dist(int u,int v) {return d[u]+d[v]-2*d[lca(u,v)];}
inline int jmp(int u,int dep) {
for(R t=B;~t;--t) if(d[fa[u][t]]>=dep) u=fa[u][t];
return u;
}
inline void dp1(int u) {
siz[u]=sz[u];
if(vis[u]) dis[u]=0,be[u]=u;
else dis[u]=Inf,be[u]=0;
for(R i=fir[u];i;i=nxt[i]) {
R v=vr[i];
dp1(v);
if(dis[u]>dis[v]+w[i]||(dis[u]==dis[v]+w[i]&&be[u]>be[v]))
dis[u]=dis[v]+w[i],be[u]=be[v];
}
}
inline void dp2(int u) {
for(R i=fir[u];i;i=nxt[i]) {
R v=vr[i];
if(dis[v]>dis[u]+w[i]||(dis[v]==dis[u]+w[i]&&be[u]<be[v]))
dis[v]=dis[u]+w[i],be[v]=be[u];
dp2(v);
if(be[u]==be[v]) siz[u]-=sz[v];
else {
R t=dis[u]+(d[v]-d[u])-dis[v];
R tmp=jmp(v,d[v]-(t-1)/2);
if(t&&t%2==0&&be[v]<be[u]) tmp=pre[tmp];
siz[v]+=sz[tmp]-sz[v];
siz[u]-=sz[tmp];
}
ans[be[v]]+=siz[v];
}
}
inline void main() {
n=g();
for(R i=1,u,v;i<n;++i)
u=g(),v=g(),add(u,v,0),add(v,u,0);
d[1]=1,dfs(1),dfs2(1,1);
memset(fir,0,sizeof fir);
m=g();
while(m--) {
k=g();
for(R i=1;i<=k;++i) vis[mem[i]=h[i]=g()]=true;
sort(h+1,h+k+1,cmp);
fir[1]=cnt=0; R top;
stk[top=1]=1;
for(R i=1+(h[1]==1),l;i<=k;++i) {
l=lca(h[i],stk[top]);
if(l!=stk[top]) {
while(dfn[l]<dfn[stk[top-1]])
add(stk[top-1],stk[top],dist(stk[top],stk[top-1])),--top;
if(dfn[l]>dfn[stk[top-1]])
fir[l]=0,
add(l,stk[top],dist(stk[top],l)),stk[top]=l;
else
add(l,stk[top],dist(stk[top],l)),--top;
}
fir[h[i]]=0,stk[++top]=h[i];
}
for(R i=1;i<top;++i)
add(stk[i],stk[i+1],dist(stk[i+1],stk[i]));
dp1(1),dp2(1); ans[be[1]]+=siz[1];
for(R i=1;i<=k;++i) printf("%d ",ans[mem[i]]),ans[mem[i]]=0; puts("");
for(R i=1;i<=k;++i) vis[h[i]]=0;
}
}
} signed main() {Luitaryi::main(); return 0;}
2020.01.17