虚树 学习笔记

听起来很高级的数据结构,但实际上很好理解。

主要用于处理标记 \(k\) 个特殊点,进行一些询问的问题,一般题中会给出 \(\sum k\) 的数据范围。


1. Idea

虚树其实就是将一棵树的所有特殊点以及特殊点两两之间的 \(\text{lca}\) 构成一棵新的树

具体来说,我们可以预处理出树上所有节点的 \(\text{dfs}\) 序,然后将特殊点按照 \(\text{dfs}\) 序排序,将排序后的每两个相邻的特殊点之间求一个 \(\text{lca}\),并将其加入特殊点的行列。

为什么 \(\text{dfs}\) 序相邻的特殊点的 \(\text{lca}\) 就可以涵盖所有特殊点两两之间的 \(\text{lca}\) 了呢?

考虑对于三个 \(\text{dfs}\) 序分别为 \(a,b,c\) 的节点,其中 \(a<b<c\) 。若 \(b\)\(a\) 的子树内,则 \(b\)\(c\)\(\text{lca}\) 显然也是 \(a\)\(c\)\(\text{lca}\);若 \(b\)\(a\) 的子树外,不妨设 \(c\)\(b\)\(\text{lca}\) 深度小于 \(a\)\(b\)\(\text{lca}\) ,则前者必定也是后者的祖先(画个图很明显),也是 \(a\) 的祖先,若不是 \(a\)\(c\)\(\text{lca}\),则一定也不是 \(b\)\(c\)\(\text{lca}\),故这两个 \(\text{lca}\) 一定也包含 \(a\)\(c\)\(\text{lca}\)。故推广后可知,必定涵盖。(怎么感觉像绕口令)

综上所述,记住就可以啦!(大雾)

处理完之后,我们再将记录特殊点的数组排序一次,去重后考虑建树。

由于特殊点是按照 \(\text{dfs}\) 序排序的,所以其实存储特殊点的数组也是按照新树的 \(\text{dfs}\) 序排序的。那么建树的问题其实就转化为,给定一棵树的 \(\text{dfs}\) 序,建出这棵树。这个问题的解决方式自然很简单,用一个类似于单调栈的栈维护当前节点的祖先,每次弹出栈顶不符合条件的祖先后栈顶的祖先即为该节点在新树中的父亲,然后将该节点加入栈顶即可。

建好树之后,我们就可以在树上进行 dp 等操作处理答案了。

特殊点一共 \(k\) 个,扩展完 \(\text{lca}\) 后最多 \(2k-1\) 个,故 dp 等操作的复杂度可以被成功地降到 \(O(\sum k)\) 相关,而由于建树时需要获得 \(k-1\)\(\text{lca}\),故所有询问的建树总复杂度为 \(O(\sum k\times\log n)\)。且为在线算法

需要注意新树中的根节点并不一定为原树中的根节点,一般情况下,新树中的根节点是所有特殊点当中 \(\text{dfs}\) 序最小的节点。


2. Example

2.1 CF613D Kingdom and its Cities

给定一棵树,每次询问给定 \(k\) 个特殊点,找出尽量少的非特殊点使得删去这些点后特殊点两两不连通。\(\sum k\le n.\)

观察到 选点+限制所有询问总点数,我们基本就可以确定这是一道虚树的题目。

先建树,如果两个特殊点的距离为 \(1\),则无解。

否则对于一个节点,如果是特殊点,则需要断开其所有需要断开的儿子,并向上传递一个需要被断开的标记;如果不是特殊点,且儿子中需要被断开的标记个数大于等于 \(2\),则断开当前节点,如果个数为 \(1\),则向上传递需要被断开的标记,个数为 \(0\) 则不传。最后的答案即为过程中一共被断开的点数。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define N 100005
int n,m,k,jtot,dtot,ltot,st[N<<1],stot,ans;
bool can[N],flag;
struct node{
	int to,next;
	node (int to=0,int next=0)
		:to(to),next(next){}
};
struct node2{
	int head[N],tot;
	node e[N<<1];
	void adde(int u,int v){
		e[++tot]=node(v,head[u]);
		head[u]=tot;
	}
}S,T;
struct node1{
	int fa,tp,zson,size,dep;
}e[N];
struct node3{
	int dfn,low,id;
}p[N],jl[N<<1];
int read(){
	int wh=0,fh=1;
	char c=getchar();
	while (c>'9'||c<'0'){
		if (c=='-') fh=-1;
		c=getchar();
	}
	while (c>='0'&&c<='9'){
		wh=(wh<<3)+(wh<<1)+(c^48);
		c=getchar();
	}
	return wh*fh;
} 
void dfs1(int u){
	p[u].dfn=++dtot,e[u].size=1;
	for (int i=S.head[u];i;i=S.e[i].next){
		int v=S.e[i].to;
		if (v!=e[u].fa){
			e[v].fa=u;
			e[v].dep=e[u].dep+1;
			dfs1(v);
			e[u].size+=e[v].size;
			if (e[v].size>e[e[u].zson].size) e[u].zson=v;
		}
	}
	p[u].low=++ltot,p[u].id=u;
}
void dfs2(int u,int tp){
	e[u].tp=tp;
	if (e[u].zson) dfs2(e[u].zson,tp);
	for (int i=S.head[u];i;i=S.e[i].next){
		int v=S.e[i].to;
		if (v!=e[u].fa&&v!=e[u].zson){
			dfs2(v,v);
		}
	}
}
bool cmp(node3 x,node3 y){
	return x.dfn<y.dfn;
}
int getlca(int x,int y){
	while (e[x].tp!=e[y].tp){
		int xp=e[x].tp,yp=e[y].tp;
		if (e[xp].dep>e[yp].dep) x=e[xp].fa;
		else y=e[yp].fa;
	}
	return e[x].dep<e[y].dep?x:y;
}
void build(){
	sort(jl+1,jl+1+jtot,cmp);
	for (int i=1;i<k;i++){
		jl[++jtot]=p[getlca(jl[i].id,jl[i+1].id)];
	}
	sort(jl+1,jl+1+jtot,cmp);
	stot=0;
//	for (int i=1;i<=jtot;i++) printf("%d ",jl[i].id);
//	puts("--------");
	st[++stot]=jl[1].id;
	for (int i=2;i<=jtot;i++){
		if (jl[i].dfn==jl[i-1].dfn) continue;
		while (stot&&(p[st[stot]].dfn>jl[i].dfn||p[st[stot]].low<jl[i].low)) --stot;
//		printf("%d %d\n",jl[i].id,st[stot]);
		if (st[stot]==e[jl[i].id].fa&&can[st[stot]]&&can[jl[i].id]){
			flag=1;
			puts("-1");
			break;
		}
		T.adde(st[stot],jl[i].id);
		st[++stot]=jl[i].id;
	}
}
void to_pre(){
	for (int i=1;i<=jtot;i++) can[jl[i].id]=0,T.head[jl[i].id]=0;
	T.tot=0;
}
int dp(int u){
//	printf("::%d\n",u);
	int sum=0;
	for (int i=T.head[u];i;i=T.e[i].next){
		int v=T.e[i].to;
		sum+=dp(v);
	}
	if (can[u]){
		ans+=sum;
		return 1;
	}else if (sum==0) return 0;
	else if (sum==1) return 1;
	ans++;return 0;
}
int main(){
	n=read();
	for (int i=1;i<n;i++){
		int u=read(),v=read();
		S.adde(u,v),S.adde(v,u);
	}
	dfs1(1),dfs2(1,1);
	m=read();
	while (m--){
		k=read();
		jtot=0;
		for (int i=1;i<=k;i++) jl[++jtot]=p[read()],can[jl[jtot].id]=1;
		build();
		if (!flag) ans=0,dp(st[1]),printf("%d\n",ans);
		else flag=0;
		to_pre();
	}
	return 0;
}

2.2 P2495 [SDOI2011]消耗战

给定一棵树,每次询问给定 \(k\) 个特殊点,需要断掉一些边使得从根节点无法到达任何特殊点,求最小需要断掉的边数。\(\sum k\le2n\).

同样,观察到给定特殊点以及关于所有询问特殊点总个数的限制,考虑使用虚树。

很容易想到对于一个虚树上的点,如果是特殊点,那么其实断掉其子树上的边已经没有用了,只能断掉该节点到根路径上的边;如果不是特殊点,则将所有子树返回上来的答案加起来之后,与到其父亲的边权取 \(\min\) 后上传给父亲(即要不断掉所有有特殊点的子树,要不断掉与根路径上的边)。最终 \(1\) 节点的答案即为最终答案。

注意由于敌军岛屿在 \(1\) 号,所以如果建立虚树后虚树的根不是 \(1\) 号节点的话,还需要将 \(1\) 号节点再向虚树的根节点连一条边,并将 \(1\) 号节点作为根节点(注意,多测清空的时候 \(1\) 号节点相关的信息也要清空!)。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define N 500005
int n,fa[N][21],dep[N],m,k,jtot,dtot,ltot,st[N],stot;
ll minn[N][21];
bool can[N];
struct node{
	int to,next;
	ll w;
	node (int to=0,int next=0,ll w=0)
		:to(to),next(next),w(w){}
};
struct node1{
	int head[N],tot;
	node e[N<<1];
	void adde(int u,int v,int w){
		e[++tot]=node(v,head[u],w);
		head[u]=tot;
	}
}S,T;
struct node2{
	int dfn,low,id;	
}p[N],jl[N<<1];
int read(){
	int wh=0,fh=1;
	char c=getchar();
	while (c>'9'||c<'0'){
		if (c=='-') fh=-1;
		c=getchar();
	}
	while (c>='0'&&c<='9'){
		wh=(wh<<3)+(wh<<1)+(c^48);
		c=getchar();
	}
	return wh*fh;
} 
void dfs(int u){
	p[u].dfn=++dtot,p[u].id=u;
	for (int i=S.head[u];i;i=S.e[i].next){
		int v=S.e[i].to;
		if (v!=fa[u][0]){
			fa[v][0]=u;
			dep[v]=dep[u]+1;
			minn[v][0]=S.e[i].w;
			for (int i=1;(1<<i)<=dep[v];i++) fa[v][i]=fa[fa[v][i-1]][i-1],minn[v][i]=min(minn[v][i-1],minn[fa[v][i-1]][i-1]);
			dfs(v);
		}
	}
	p[u].low=++ltot;
}
bool cmp(node2 x,node2 y){
	return x.dfn<y.dfn;
}
int getlca(int x,int y){
	if (dep[x]<dep[y]) swap(x,y);
	for (int i=20;i>=0;i--)
		if (dep[x]-(1<<i)>=dep[y]) x=fa[x][i];
	if (x==y) return x;
	for (int i=20;i>=0;i--)
		if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}
ll query(int x,int y){
	ll sum=1e15;
	for (int i=20;i>=0;i--)
		if (dep[y]-(1<<i)>=dep[x]) sum=min(sum,minn[y][i]),y=fa[y][i];
	return sum;
}
void build(){
	sort(jl+1,jl+1+jtot,cmp);
	for (int i=1;i<k;i++) jl[++jtot]=p[getlca(jl[i].id,jl[i+1].id)];
	sort(jl+1,jl+1+jtot,cmp);
//	for (int i=1;i<=jtot;i++) printf("%d ",jl[i].id);
//	puts("--------");
	st[stot=1]=jl[1].id;
	for (int i=2;i<=jtot;i++){
		if (jl[i].dfn==jl[i-1].dfn) continue;
		while (stot&&(p[st[stot]].dfn>jl[i].dfn||p[st[stot]].low<jl[i].low)) --stot;
		T.adde(st[stot],jl[i].id,query(st[stot],jl[i].id));
		st[++stot]=jl[i].id;
	}
	if (st[1]!=1) T.adde(1,st[1],query(1,st[1]));
}
void to_pre(){
	can[1]=T.head[1]=0;
	for (int i=1;i<=jtot;i++) can[jl[i].id]=0,T.head[jl[i].id]=0;
	T.tot=0;
}
ll dp(int u){
	ll sum=0;
	for (int i=T.head[u];i;i=T.e[i].next){
		int v=T.e[i].to;
		ll w=min(dp(v),T.e[i].w);
//		printf("%d %d %d\n",u,v,w);
		sum+=w;
	}
	if (can[u]) return 1e12;
	return sum;
}
int main(){
	n=read();
	for (int i=1;i<n;i++){
		int u=read(),v=read(),w=read();
		S.adde(u,v,w),S.adde(v,u,w);
	}
	dfs(1);
	m=read();
	while (m--){
		k=read();
		jtot=0;
		for (int i=1;i<=k;i++) jl[++jtot]=p[read()],can[jl[i].id]=1;
		build();
		printf("%lld\n",dp(1));
		to_pre();
	}
	return 0;
}

2.3 P4103 [HEOI2014]大工程

给定一棵树,每次询问给定 \(k\) 个特殊点,求它们两两之间距离的距离和,最小距离和最大距离。\(\sum k\le2n\).

有了前两题的经验,很自然地就想到了虚树。

先想想普通方法怎么做。

分别考虑三个问题。距离和其实最好求,枚举每一条边,该边的贡献其实就是其左侧的特殊点个数和其右边的特殊点个数的乘积再乘上该边的边权,对所有边的贡献求和即为答案。

最小距离和最大距离方法类似,考虑树形 dp,对于每个节点分别其子树内过该节点的答案。对于最小距离,遍历到一个节点时,如果该节点为特殊点,则该节点子树内的最小距离显然是其子树内特殊点到该节点的最小距离;如果不是,最小距离则为其子树内特殊点到该节点的最小距离和次小距离之和。对于最大距离,则最大距离为子树内最大距离与次大距离之和,如果没有次大距离且当前节点为特殊点时,次大距离可以用 \(0\) 代替。而这些东西都可以通过简单的递归与返回处理。

那么其实虚树上的求解和普通树上的求解基本相同,只不过需要将虚树上的链权赋值为该链上边权的最小值,可以通过倍增或树剖轻松处理。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define N 1000005
int n,dtot,ltot,jtot,fa[N][22],dep[N],st[N<<1],stot,ans2,ans3,size[N],k;
ll ans1;
bool can[N];
struct node{
	int head[N],tot;
	int to[N<<1],next[N<<1],w[N<<1],from[N<<1];
	void adde(int u,int v,int ww){
		++tot;
		from[tot]=u,to[tot]=v,next[tot]=head[u],w[tot]=ww;
		head[u]=tot;
	}
}S,T;
struct node1{
	int id,dfn,low;
}p[N],jl[N<<1];
int read(){
	int wh=0,fh=1;
	char c=getchar();
	while (c>'9'||c<'0'){
		if (c=='-') fh=-1;
		c=getchar();
	}
	while (c>='0'&&c<='9'){
		wh=(wh<<3)+(wh<<1)+(c^48);
		c=getchar();
	}
	return wh*fh;
} 

void dfs(int u){
	p[u].dfn=++dtot;
	for (int i=S.head[u];i;i=S.next[i]){
		int v=S.to[i];
		if (v!=fa[u][0]){
			dep[v]=dep[u]+1;
			fa[v][0]=u;
			for (int i=1;(1<<i)<=dep[v];i++) fa[v][i]=fa[fa[v][i-1]][i-1];
			dfs(v);
		}
	}
	p[u].low=++ltot;
	p[u].id=u;
}
bool cmp(node1 x,node1 y){
	return x.dfn<y.dfn;
}
int getlca(int x,int y){
	if (dep[x]<dep[y]) swap(x,y);
	for (int i=21;i>=0;i--)
		if (dep[x]-(1<<i)>=dep[y]) x=fa[x][i];
	if (x==y) return x;
	for (int i=21;i>=0;i--)
		if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}
void check(int u){
	printf("-----%d\n",u);
	for (int i=T.head[u];i;i=T.next[i]){
		int v=T.to[i];
		check(v);
	}
}
void build(){
	k=read();
	jtot=0;
	for (int i=1;i<=k;i++) jl[++jtot]=p[read()],can[jl[i].id]=1;
	sort(jl+1,jl+1+jtot,cmp);
	for (int i=1;i<k;i++)
		jl[++jtot]=p[getlca(jl[i].id,jl[i+1].id)];
	sort(jl+1,jl+1+jtot,cmp);
	stot=0;
	st[++stot]=jl[1].id;
//	puts("------");
//	for (int i=1;i<=jtot;i++) printf("%d ",jl[i].id);
//	puts("\n------");
	for (int i=2;i<=jtot;i++){
		if (jl[i].dfn==jl[i-1].dfn) continue;
	//	printf("%d %d %d %d %d %d\n",st[stot],jl[i].id,p[st[stot]].dfn,jl[i].dfn,p[st[stot]].low,jl[i].low);
		while (stot&&(p[st[stot]].low<jl[i].low||p[st[stot]].dfn>jl[i].dfn)) --stot;
		T.adde(st[stot],jl[i].id,dep[jl[i].id]-dep[st[stot]]);
	//	printf("%d %d %d\n",st[stot],jl[i].id,dep[jl[i].id]-dep[st[stot]]);
		st[++stot]=jl[i].id;
	}
//	check(st[1]);
}
void dfs1(int u){
	if (can[u]) size[u]=1;
	for (int i=T.head[u];i;i=T.next[i]){
		int v=T.to[i];
		dfs1(v);
		size[u]+=size[v];	
	}
}
void dfs2(int u){
	size[u]=0;
	for (int i=T.head[u];i;i=T.next[i]){
		int v=T.to[i];
		dfs2(v);
	}
}
int query_sum(){
	dfs1(st[1]);
	for (int i=1;i<=T.tot;i++){
		int u=T.from[i],v=T.to[i];
		ans1+=(ll)size[v]*(ll)(k-size[v])*(ll)T.w[i];
	}
	dfs2(st[1]);
}
int query_max(int u){
	int maxn=0,cmax=0;
	for (int i=T.head[u];i;i=T.next[i]){
		int v=T.to[i],w=T.w[i];
		int now=query_max(v)+w;
		if (now>maxn) cmax=maxn,maxn=now;
		else if (now>cmax) cmax=now;
	}
	if (cmax||maxn&&can[u]) ans3=max(ans3,maxn+cmax);
	return maxn;
}
int query_min(int u){
	int minn=1e9,cmin=1e9;
	for (int i=T.head[u];i;i=T.next[i]){
		int v=T.to[i],w=T.w[i];
		int now=query_min(v)+w;
		if (now<minn) cmin=minn,minn=now;
		else if (now<cmin) cmin=now;
	}
	if (can[u]){
		if (minn) ans2=min(ans2,minn);
		return 0;
	}
	if (cmin)
		ans2=min(ans2,minn+cmin);
	return minn;
}
int main(){
	n=read();
	for (int i=1;i<n;i++){
		int u=read(),v=read();
		S.adde(u,v,1),S.adde(v,u,1);
	}
	dfs(1);
	int q=read();
	while (q--){
		build();
		ans1=ans3=0;ans2=1e9;
		query_sum(),query_min(st[1]),query_max(st[1]);
		printf("%lld %d %d\n",ans1,ans2,ans3);
		for (int i=1;i<=jtot;i++) T.head[jl[i].id]=0,can[jl[i].id]=0;
		T.tot=0;
	}
	return 0;
}
posted @ 2022-05-16 11:19  ydtz  阅读(64)  评论(0编辑  收藏  举报