虚树
虚树
我们有这样一类问题,对于一个有 \(n\) 个点的树,有 \(m\) 组询问,每次询问给出 \(k\) 个关键点,关键点之间的最短距离。
其中 \(\sum k\le n,n\le 1e5\)。
我们发现这一类问题如果我们对于每一组询问都跑一遍 \(O(n)\) 的 \(DP\),显然复杂度是不对的,但我们发现有用的点的和为 \(O(n)\),那就启示我们将关键点单独处理,就诞生了虚树。
我们将关键点与他们的 \(LCA\) 建成一棵树,我们就可以在这棵树上跑 \(DP\) 了,这样总复杂度就为 \(O(n)\) 了。
对于如何建虚树,我们可以使用一个栈来维护当前所建虚树的右链,具体建法如下。
首先我们求出栈顶节点与当前所加节点的 \(LCA\),如果 \(LCA\) 为栈顶节点,那么我们不做任何操作,将当前所加节点入栈,如果不等,有这样一些情况。
如果我们的栈中的次栈顶元素的 \(dfn\) 序大于 \(LCA\) 的 \(dfn\) 序,我们就将次栈顶元素与栈顶元素连边,并将栈顶出栈,直到次栈顶元素的 \(dfn\) 序小于等于 \(LCA\) 的 \(dfn\) 序为止。
当我们的次栈顶元素的 \(dfn\) 序小于等于 \(LCA\) 的 \(dfn\) 序时,我们分两种情况讨论:
\(1.\) 我们的次栈顶元素的 \(dfn\) 序等于 \(LCA\) 的 \(dfn\) 序,说明当前所维护的右链不如此时新加的节点更右,我们就将 \(LCA\) 与栈顶元素连边,并将栈顶出栈即可。
\(2.\) 我们的次栈顶元素的 \(dfn\) 序小于 \(LCA\) 的 \(dfn\) 序,
说明当前所维护的右链不如此时新加的节点更右,但是 \(LCA\) 必须作为一个新节点加入虚树,所以我们将 \(LCA\) 与栈顶元素连边,并将栈顶元素换为 \(LCA\)。
最后我们将所加节点入栈。
当我们将所有关键点都加入过栈后,如果栈不空,将栈中元素按顺序连边即可建好虚树。
对于要建虚树的题,往往建树不是难点,而难点在于建好树后的 \(DP\)
\(code\)
#include<bits/stdc++.h>
using namespace std;
const int N=2e6+10;
const long long INF=0x3f3f3f3f3f3f3f3f;
int n,k;
long long ans1,ans2,ans3;
struct Graph{
int head[N],tot,dfn[N],dep[N],size[N],top[N],fa[N],son[N],ntime;
struct Node{int to,nest;}bian[N<<1];
void add(int x,int y){bian[++tot]=(Node){y,head[x]};head[x]=tot;}
void get_heavy_son(int x,int f,int depth){
dep[x]=depth,fa[x]=f,size[x]=1,son[x]=0;
for(int i=head[x];i;i=bian[i].nest){
int v=bian[i].to;
if(v==f)continue;
get_heavy_son(v,x,depth+1);
size[x]+=size[v];
if(size[v]>size[son[x]])son[x]=v;
}
}
void get_heavy_edge(int x,int tp){
top[x]=tp;dfn[x]=++ntime;
if(son[x])get_heavy_edge(son[x],tp);
for(int i=head[x];i;i=bian[i].nest){
int v=bian[i].to;
if(v==son[x]||v==fa[x])continue;
get_heavy_edge(v,v);
}
}
int LCA(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
x=fa[top[x]];
}
return dep[x]<dep[y]?x:y;
}
}G;
inline bool cmp(int x,int y){return G.dfn[x]<G.dfn[y];}
struct Virtual_Tree{
int head[N],tot,stk[N],top,vis[N],kiss[N],size[N];//kiss -> key_point
long long dp[N],mink[N],maxk[N];
struct Node{int to,nest;}bian[N<<1];
void add(int x,int y){bian[++tot]=(Node){y,head[x]};head[x]=tot;}
void in(){for(int i=1;i<=k;i++)scanf("%d",&kiss[i]),vis[kiss[i]]=1;}
void build(){
sort(kiss+1,kiss+1+k,cmp);
tot=0,top=0;
stk[++top]=1;
for(int i=1;i<=k;i++){
if(kiss[i]==1)continue;
int now=kiss[i];
int lca=G.LCA(now,stk[top]);
if(lca!=stk[top]){
while(G.dfn[lca]<G.dfn[stk[top-1]])add(stk[top-1],stk[top]),top--;
if(G.dfn[lca]!=G.dfn[stk[top-1]])add(lca,stk[top]),stk[top]=lca;
else add(lca,stk[top]),top--;
}
stk[++top]=now;
}
for(int i=top;i>=2;i--)add(stk[i-1],stk[i]);
top=0;
}
void DP(int x){
if(vis[x])size[x]=1,mink[x]=0,maxk[x]=0;
else size[x]=0,mink[x]=INF,maxk[x]=-INF;
vis[x]=0;
dp[x]=0;
for(int &i=head[x];i;i=bian[i].nest){//( & )-> 多测清空
int v=bian[i].to;
DP(v);
long long dis=G.dep[v]-G.dep[x];
ans1+=dp[x]*size[v]+dp[v]*size[x]+dis*size[x]*size[v];
ans2=min(ans2,mink[x]+mink[v]+dis),ans3=max(ans3,maxk[x]+maxk[v]+dis);
dp[x]+=size[v]*dis+dp[v];
size[x]+=size[v];
mink[x]=min(mink[x],mink[v]+dis);
maxk[x]=max(maxk[x],maxk[v]+dis);
}
}
}T;
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++){
int s1=0,s2=0;
scanf("%d%d",&s1,&s2);
G.add(s1,s2),G.add(s2,s1);
}
G.get_heavy_son(1,0,1);
G.get_heavy_edge(1,1);
int Q=0;
scanf("%d",&Q);
for(int i=1;i<=Q;i++){
scanf("%d",&k);
T.in();
T.build();
ans1=0,ans2=INF,ans3=-INF;
T.DP(1);
printf("%lld %lld %lld\n",ans1,ans2,ans3);
}
return 0;
}