虚树【学习笔记】
为什么要用虚树?
例题
在某些树上问题中,对于某次询问,我们并不需要用到全部的树上的点:
例如,例题中:
总点数 \(n \le 2.5\times10^5\)
询问次数 \(m \le 5\times10^5\)
询问的点数 \(\sum k_i \le 5\times10^5\)
我们可以发现其实每次询问均摊下来的询问点数k并不多,但如果每次询问都用到全部的点,会超时
所以我们将所有的关键点拎出来建树,来确保时间复杂度的优秀
朴素做法
我们回到例题上来,可以想到如果树的点数很少时,我们可以直接用 \(DP\) :
首先我们设某次询问中被选中的点(
资源丰富)为 关键点
\(dp_i\) 表示不让 \(i\) 与 \(i\) 的子树内任意一个关键点互通所需要的最小代价
\(w_{u,v}\) 表示连接 \(u\) 和 \(v\) 的边权
\(u\) 表示 \(i\) 连接的一个儿子节点
转移方程式:
- 当 \(u\) 是关键点时 : 你必须砍掉 \(i\) 到 \(u\) 的这条边
\(dp_i+=w_{i,u}\)
- 当 \(u\) 不是关键点时 :你可以选择砍掉 \(i\) 到 \(u\) 的这条边或者让 \(u\) 不连接关键点
\(dp_i+=min(w_{i,u},dp_u)\)
此时时间复杂度为 \(O(nq)\) 肯定过不了,考虑用虚树建一颗更简洁的树(没有那么多用不到的点)
虚树做法
在原树中,我们可以发现大多数点是没用的,以下图为例:
如果我们选取的关键点是2,4:
图中只有两个红色的点是关键点,而别的点全都是非关键点,对于这道题来说,我们只需要保证 1 号节点无法到达2,4就行了而 1 号节点的右子树没有一个关键点,我们没必要去DP它
观察题目给出的条件,红色点(关键点)的总数是与 n 同阶的,也就是说实际上一次询问中红色的点对于整棵树来说是很稀疏的,所以如果我们能让复杂度由关键点的总数来决定就好了
所以我们需要浓缩信息,只存储与答案相关的信息,把一整颗大树浓缩成小树
虚树长什么样?
这里我们主要通过一些图理解(感谢oiwiki我才不用画图)
下图中,红色结点是我们选择的关键点,红色和黑色结点都是虚树中的点(要把某些红色节点相连必须用到黑色节点),黑色的边是虚树中的边。
因为任意两个关键点的 \(lca\) 也是需要保存重要信息的,所以我们需要保存它们的 \(lca\),因此虚树中不一定只有关键点。
怎么构造虚树?
这里介绍的是二次排序+ lca 连边的方法,还有一种单调栈的构造方法,详见oiwiki
- 将关键节点按照 \(dfs\)序 排序,并插入序列 \(a\)
- 关键节点中两两求 \(lca\),插入序列 \(a\) 中
- 在将序列 \(a\) 按照 \(dfs\)序 排序,并去重
- 遍历序列 \(a\) ,枚举相邻两个点的编号(设为 \(x\) ,\(y\) )求 \(lca\) ,建一条由 \(lca\) 指向 \(y\) 的边
为什么连接 \(LCA(x,y)\) 和 \(y\) 可以做到不重不漏呢?
证明:
如果 \(x\) 是 \(y\) 的祖先,那么 \(x\) 直接到 \(y\) 连边。因为 \(dfs\)序保证了 \(x\) 和 \(y\) 的 \(dfs\)序是相邻的,所以 \(x\) 到 \(y\) 的路径上面没有关键点。
如果 \(x\) 不是 \(y\) 的祖先,那么就把 \(lca(x,y)\) 当作 \(y\) 的的祖先,根据上一种情况也可以证明 \(lca(x,y)\) 到 \(y\) 点的路径上不会有关键点。
所以连接 \(lca(x,y)\) 和 \(y\),不会遗漏,也不会重复。
另外第一个点没有被一个节点连接会不会有影响呢?因为第一个点一定是这棵树的根,所以不会有影响,所以总边数就是 \(m-1\) 条。
因为至少要两个实点才能够召唤出来一个虚点,再加上一个根节点,所以虚树的点数就是实点数量的两倍。
时间复杂度 \(O(klog_n)\),其中 \(k\) 为关键点数,\(n\) 为总点数。
实现:
int dfn[maxn]
int h[maxn], a[maxn], cnt; // 存储关键点
bool cmp(int x,int y){
return dfn[x]<dfn[y];
}
void buid{
h[++k]=1;//为了方便,我们首先将1号节点加入虚树中
sort(h+1,h+1+k,cmp);//操作1,按照dfs序排序
for (int i=1; i<=k; i++) {
a[++cnt]=h[i];//将关键点插入序列a
if (i==k) break;
//操作2,两两求lca插入序列a中
a[++cnt]=lca(h[i],h[i+1]);
}
sort(a+1,a+1+cnt,cmp);//操作3,排序
cnt=unique(a+1,a+1+cnt)-(a+1);//去重
for (int i=1; i<cnt; i++) {
int lc=lca(a[i],a[i+1]);
add(lc,a[i+1],0);//操作4,连一条由lca(x,y)指向y的边
}
}
回到例题
虚树建好后,这道题就很好攻克了
设 \(miv_i\) 表示 \(i\) 到 1 号节点边权最小的一条边(容易理解的是:割掉这条边后,\(i\) 就不再与 1 号节点相连了)
\(col_i\)记录 \(i\) 是否为关键点(是关键点为1,否则为0)
- \(miv\) 和一些其他数组的预处理
void dfs1(int x,int fa){
vis[x]=1;
dfn[x]=++cnt;
for (int i=he[x];i;i=ne[i])
if (!vis[to[i]]){
d[to[i]]=d[x]+1;
f[to[i]][0]=x;
miv[to[i]]=min(miv[x],w[i]);
dfs1(to[i],x);
}
he[x]=0;
}
- 求解让 1 号节点不与( \(x\) 及 \(x\) 的子树中的关键点)连通的最小代价
int dfs2(int x,int fa){
int tmp=0,ans;
for (int i=he[x];i;i=ne[i])
tmp+=dfs2(to[i],x);
if (col[x]) ans=miv[x];
else ans=min(miv[x],tmp);
he[x]=0;col[x]=0;
//多次询问,可以在递归中直接清空
return ans;
}
完整代码(*╹▽╹*)
#include<bits/stdc++.h>
#define int long long
#define pai pair<int,int>
#define mk make_pair
#define fi first
#define se second
using namespace std;
const int maxn=1e6+10;
const int N=30;
const int INF=1e18;
int read(){
int x=0,f=1;char c=getchar();
while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();}
while (c>='0'&&c<='9') {x=(x<<1)+(x<<3)+(c^48);c=getchar();}
return x*f;
}
int tot,n,q,cnt,k,d[maxn];
int miv[maxn],dfn[maxn];
int he[maxn],w[maxn<<1];
int ne[maxn<<1],to[maxn<<1];
int h[maxn],a[maxn];
int f[maxn][N];
bool col[maxn],vis[maxn];
void add(int u,int v,int z){
ne[++tot]=he[u];
he[u]=tot;
to[tot]=v;
w[tot]=z;
}
bool cmp(int x,int y){
return dfn[x]<dfn[y];
}
void dfs1(int x,int fa){
vis[x]=1;
dfn[x]=++cnt;
for (int i=he[x];i;i=ne[i])
if (!vis[to[i]]){
d[to[i]]=d[x]+1;
f[to[i]][0]=x;
miv[to[i]]=min(miv[x],w[i]);
dfs1(to[i],x);
}
he[x]=0;
}
void init(){
for (int j=1;j<=20;j++)
for (int i=1;i<=n;i++)
f[i][j]=f[f[i][j-1]][j-1];
}
int lca(int x,int y){
if (x==y) return x;
if (d[x]<d[y]) swap(x,y);
for (int j=log2(d[x]);j>=0;j--)
if (d[f[x][j]]>=d[y])
x=f[x][j];
if (x==y) return x;
for (int j=log2(d[x]);j>=0;j--)
if (f[x][j]!=f[y][j])
x=f[x][j],y=f[y][j];
return f[x][0];
}
int dfs2(int x,int fa){
int tmp=0,ans;
for (int i=he[x];i;i=ne[i])
tmp+=dfs2(to[i],x);
if (col[x]) ans=miv[x];
else ans=min(miv[x],tmp);
he[x]=0;col[x]=0;
return ans;
}
signed main(){
n=read();
for (int i=1,x,y,z;i<n;i++){
x=read();y=read();z=read();
add(x,y,z);add(y,x,z);
}
d[0]=-INF,miv[1]=INF;
dfs1(1,0);init();q=read();
while (q--){
k=read();tot=cnt=0;
for (int i=1,x;i<=k;i++){
x=read();
h[i]=x;
col[x]=1;
}
h[++k]=1;
sort(h+1,h+1+k,cmp);
for (int i=1;i<=k;i++){
a[++cnt]=h[i];
if (i==k) break;
a[++cnt]=lca(h[i],h[i+1]);
}
sort(a+1,a+1+cnt,cmp);
cnt=unique(a+1,a+1+cnt)-(a+1);
for (int i=1;i<cnt;i++){
int lc=lca(a[i],a[i+1]);
add(lc,a[i+1],0);
}
printf("%lld\n",dfs2(1,0));
for (int i=1;i<=cnt;i++)
he[i]=0,col[i]=0;
}
return 0;
}