虚树 学习笔记

虚树 学习笔记

引入

我们在解决树上问题时,往往都是对整棵树进行处理,或者每次询问都对一个点、点对进行处理,这类题型一般都可以通过 dp、树剖解决;然而,有一类问题要求我们每次对树上一些关键点进行处理。这类问题的特点就是询问次数多,而询问的点的总数不多。可如果我们每次都把整棵树都 dfs 一遍,时间复杂度就是 \(n^2\) 级别的。我们发现,每次 dfs 的时候,有用的只有关键点,我们所关注的也只有这些关键点之间的关系,那我们是不是可以考虑把整棵树抽象起来,变成一颗只有关键点和某些辅助点的树呢?

虚树

我们把关键点和用于体现关键点之间关系的辅助点连接起来,就形成了虚树。这些辅助点一般都是 LCA。由于至少需要一个关键点才会出现一个辅助点(比如某个点 是 “两个关键点的 LCA” 与 “另一个关键点” 的 LCA),所以最后建出来所有的虚树并遍历的总代价是 \(2n\) 级别的。

建树过程

我们肯定希望辅助点越少越好,但同时还得保证信息正确,所以我们考虑按照 dfs 序建树,因为 dfs 序越相近,两个点在树上的关系越近。我们先把关键点按 dfn 排序,然后用一个栈来维护一条虚树上的链,每次都询问栈顶是否为新点和栈顶的 LCA,如果不是,说明要开一条新的链,就弹栈并加边。最后一定要把栈内剩余元素加边。
参考代码:

void build(){
	sort(p+1, p+K+1, cmp);
	top = 0;
	stk[++top] = 1;
	G2.head[1] = 0;//注意不能全部清空,在加边的过程中动态清空即可。
	for(int i = 1; i<=K; ++i){
		if(p[i] == 1) continue;
		int lca = th.LCA(stk[top], p[i]);
		if(lca != stk[top]){
			while(dfn[lca] < dfn[stk[top-1]]){
				G2.add(stk[top-1], stk[top]);
				--top;
			}
			if(dfn[lca] > dfn[stk[top-1]]){
				G2.head[lca] = 0;
				G2.add(lca, stk[top]), stk[top] = lca;
			} else{
				G2.add(lca, stk[top]);
				--top;
			}
		}
		G2.head[p[i]] = 0;
		stk[++top] = p[i];
	}
	for(int i = 1; i<top; ++i){
		G2.add(stk[i], stk[i+1]);
	}
} 

例题

消耗战

首先 dp 式子很明显,我们分类讨论。如果子节点 \(v\) 是关键点,那么 \(f_u+=w(u, v)\);如果不是,那么就是 \(f_u+= \min{f_v, w(u, v)}\)。我们可以预处理出来根节点到每个节点的路径上所经过的最小边权,把它作为虚树上的边权即可。
代码:

点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2.5e5+10;

inline int read(){
	int x = 0; char ch = getchar();
	while(ch<'0' || ch>'9') ch = getchar();
	while(ch>='0'&&ch<='9') x = x*10+ch-48, ch = getchar();
	return x;
}
struct node{
	int nxt, to, w;
};
struct Graph{
	int head[N], tot;
	int num;
	node edge[N<<1];
	void add(int u, int v, int w){
		edge[++tot].nxt = head[u];
		edge[tot].to = v;
		edge[tot].w = w;
		head[u] = tot;
	}
}G1, G2;//原树,虚树。 

int dfn[N];
struct HPD{//重链剖分,heavy path decomposition 
	int siz[N], totd, top[N], son[N], dep[N], fa[N];
	void dfs1(int u, int fath){
		dep[u] = dep[fath]+1;
		siz[u] = 1;
		fa[u] = fath;
		for(int i = G1.head[u]; i; i = G1.edge[i].nxt){
			int v = G1.edge[i].to;
			if(v == fath) continue;
			dfs1(v, u);
			siz[u]+=siz[v];
			if(siz[son[u]]<siz[v]) son[u] = v;
		}
	}
	void dfs2(int u, int Top){
		top[u] = Top;
		dfn[u] = ++totd;
		if(!son[u]) return;
		dfs2(son[u], Top);
		for(int i = G1.head[u]; i; i = G1.edge[i].nxt){
			int v = G1.edge[i].to;
			if(!dfn[v]) dfs2(v, v);
		}
	}
	int LCA(int x, int y){
		while(top[x] != top[y]){
			if(dep[top[x]] < dep[top[y]]) swap(x, y);
			x = fa[top[x]];
		}
		if(dep[x] > dep[y]) swap(x, y);
		return x;
	}
}th; 

int n;
int m, K;
int dst[N], p[N];

void dfsG1(int u, int fath){
	for(int i = G1.head[u]; i; i = G1.edge[i].nxt){
		int v = G1.edge[i].to;
		if(v == fath) continue;
		dst[v] = min(dst[u], G1.edge[i].w);
		dfsG1(v, u);
	}
}

bool cmp(int a, int b){
	return dfn[a] < dfn[b];
}
int stk[N], tp;
bool is_tar[N];
void build(){
	sort(p+1, p+K+1, cmp);
	tp = 0;
	stk[++tp] = 1, G2.head[1] = 0;
	for(int i = 1, l; i<=K; ++i){
		if(p[i] == 1) continue;
		l = th.LCA(p[i], stk[tp]);
		if(l != stk[tp]){
			while(dfn[l] < dfn[stk[tp-1]]){
				G2.add(stk[tp-1], stk[tp], dst[stk[tp]]);
				--tp;
			}
			if(dfn[l] > dfn[stk[tp-1]]){
				G2.head[l] = 0;
				G2.add(l, stk[tp], dst[stk[tp]]), stk[tp] = l;
			} else{
				G2.add(l, stk[tp], dst[stk[tp]]);
				--tp;
			}
			
		}
		G2.head[p[i]] = 0;
		stk[++tp] = p[i];
	}
	for(int i = 1; i<tp; ++i){
		G2.add(stk[i], stk[i+1], dst[stk[i+1]]);	
	}
}	

ll f[N];
void dfs_ans(int u, int fath){
	f[u] = 0;
	for(int i = G2.head[u]; i; i = G2.edge[i].nxt){
		int v = G2.edge[i].to;
		if(v == fath) continue;
		dfs_ans(v, u);
		if(is_tar[v]){
			f[u]+=G2.edge[i].w;
		} else{
			f[u]+= min(f[v], 1ll*G2.edge[i].w);
		}
	}
}
int main(){
	n = read();
	dst[1] = 0x3f3f3f3f;
	for(int i = 1; i<n; ++i){
		int u = read(), v = read(), w = read();
		G1.add(u, v, w);
		G1.add(v, u, w); 
	}
	th.dfs1(1, 0);
	th.dfs2(1, 1);
	dfsG1(1, 0);
	m = read();
	while(m--){
		K = read();
		G2.tot = 0;//一定注意要清空!
		for(int i = 1; i<=K; ++i){
			p[i] = read();
			is_tar[p[i]] = 1;
		} 
		build();
		dfs_ans(1, 0);
		printf("%lld\n", f[1]);
		for(int i = 1; i<=K; ++i){
			is_tar[p[i]] = 0;
		}
	}
	return 0;
} 

大工程

也是考虑建好虚树后怎么做。最大值和最小值都可以通过拼接求得,每次找出最大/最小和次大/次小值拼接即可。至于路径权值和,我们考虑每条边的贡献,发现就等于这条边所连接的两棵子树中关键点数量的乘积。至于建好虚树后新边的边权,因为是单位边权,所以直接通过两点的深度做差即可求得。
代码:

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e6+100;
const int INF = 0x3f3f3f3f;

inline int read(){
	int x = 0; char ch = getchar();
	while(ch<'0' || ch>'9') ch = getchar();
	while(ch>='0'&&ch<='9') x = x*10+ch-48, ch = getchar();
	return x;
}
struct node{
	int nxt, to;
};
struct Graph{
	int tot, head[N];
	node edge[N<<1];
	void add(int u, int v){
		edge[++tot].nxt = head[u];
		edge[tot].to = v;
		head[u] = tot;
	}
}G1, G2;
int dep[N], dfn[N], totd;
struct HPD{
private:
	int fa[N], top[N], son[N], siz[N];
public:
	void dfs1(int u, int fath){
		dep[u] = dep[fath]+1;
		fa[u] = fath;
		siz[u] = 1;
		for(int i = G1.head[u]; i; i = G1.edge[i].nxt){
			int v = G1.edge[i].to;
			if(v == fath) continue;
			dfs1(v, u);
			siz[u]+=siz[v];
			if(siz[son[u]] < siz[v]) son[u] = v;
		}
	}
	void dfs2(int u, int Top){
		dfn[u] = ++totd;
		top[u] = Top;
		if(!son[u]) return;
		dfs2(son[u], Top);
		for(int i = G1.head[u]; i; i = G1.edge[i].nxt){
			int v = G1.edge[i].to;
			if(!dfn[v]) dfs2(v, v);
		}
	}
	inline int LCA(int x, int y){
		while(top[x] != top[y]){
			if(dep[top[x]] < dep[top[y]]) swap(x, y);
			x = fa[top[x]];
		}
		if(dep[x] > dep[y]) swap(x, y);
		return x;
	}
}th;

int K;
int stk[N], top;
int p[N];
bool is_tar[N];
bool cmp(int x, int y){
	return dfn[x] < dfn[y];
}
void build(){
	sort(p+1, p+K+1, cmp);
	top = 0;
	stk[++top] = 1;
	G2.head[1] = 0;
	for(int i = 1; i<=K; ++i){
		if(p[i] == 1) continue;
		int lca = th.LCA(stk[top], p[i]);
		if(lca != stk[top]){
			while(dfn[lca] < dfn[stk[top-1]]){
				G2.add(stk[top-1], stk[top]);
				--top;
			}
			if(dfn[lca] > dfn[stk[top-1]]){
				G2.head[lca] = 0;
				G2.add(lca, stk[top]), stk[top] = lca;
			} else{
				G2.add(lca, stk[top]);
				--top;
			}
		}
		G2.head[p[i]] = 0;
		stk[++top] = p[i];
	}
	for(int i = 1; i<top; ++i){
		G2.add(stk[i], stk[i+1]);
	}
} 
int fmn[N], fmx[N]; long long fsum[N];
int mn, mx;
long long sum;
void dfs_ans(int u, int fath){
	int firmn = INF, secmn = INF;
	fmn[u] = INF, fmx[u] = 0;
	int firmx = 0, secmx = 0;
	fsum[u] = 0;
	if(is_tar[u]){
		fsum[u] = 1;
	}
	for(int i = G2.head[u]; i; i = G2.edge[i].nxt){
		int v = G2.edge[i].to;
		if(v == fath) continue;
		dfs_ans(v, u);
		if(is_tar[v]){
			fmn[u] = min(fmn[u], dep[v]-dep[u]);
			if(dep[v]-dep[u] < firmn){
				secmn = firmn;
				firmn = dep[v]-dep[u];
			} else if(dep[v]-dep[u]<secmn){
				secmn = dep[v]-dep[u];
			}
		} else{
			fmn[u] = min(fmn[v]+dep[v]-dep[u], fmn[u]);
			if(fmn[v]+dep[v]-dep[u] < firmn){
				secmn = firmn;
				firmn = fmn[v]+dep[v]-dep[u];
			} else if(fmn[v]+dep[v]-dep[u]<secmn){
				secmn = fmn[v]+dep[v]-dep[u];
			}
		}
		fmx[u] = max(fmx[u], fmx[v]+dep[v]-dep[u]);
		if(fmx[v]+dep[v]-dep[u] > firmx){
			secmx = firmx;
			firmx = fmx[v]+dep[v]-dep[u];
		} else if(fmx[v]+dep[v]-dep[u] > secmx){
			secmx = fmx[v]+dep[v]-dep[u];
		}
		fsum[u]+=fsum[v];
		sum+=(fsum[v]*(K-fsum[v])*(dep[v]-dep[u]));
	}
	if(is_tar[u]){
		mn = min(mn, fmn[u]);
	} else{
		mn = min(mn, secmn+firmn);
	}
	if(secmx){
		mx = max(mx, firmx+secmx);
	} else if(is_tar[u]){
		mx = max(fmx[u], mx);
	}
}
int n, Q;
int main(){
	n = read();
	for(int i = 1; i<n; ++i){
		int u = read(), v = read();
		G1.add(u, v);
		G1.add(v, u);
	}
	th.dfs1(1, 0);
	th.dfs2(1, 1);
	Q = read();
	while(Q--){
		K = read();
		G2.tot = 0;
		for(int i = 1; i<=K; ++i){
			p[i] = read();
			is_tar[p[i]] = 1;
		}
		build();
		sum = mx = 0, mn = INF;
		dfs_ans(1, 0);
		printf("%lld %d %d\n", sum, mn, mx);
		for(int i = 1; i<=K; ++i){
			is_tar[p[i]] = 0;
		}
	}
	return 0;
}
posted @ 2023-07-13 15:03  霜木_Atomic  阅读(18)  评论(0编辑  收藏  举报