虚树详解

我们先从一道经典的例题入手:

[SDOI2011]消耗战

题意:

给出一棵树,每条边有边权。

有m次询问,每次询问给出k个点,问使得这k个点均不与1号点(根节点)相连的最小代价

\(1\leq n\leq 2.5\times 10^5,1\leq m\leq 5\times 10^5,1\leq \sum\limits k\leq 5\times 10^5\)

暴力dp:

设dp[u]为以u为根的子树中,割掉所有给定点的最小代价

转移的时候要分两种情况:

1.若u不是给定点,则dp[u] = min(u到根节点的所有边的最小边长,割掉所有含有给定点的子树)

2.若u是给定点,显然他必须与1号点分离,所以dp[u]=u到根节点的所有边的最小边长

这为什么对呢?为什么不会重复计算呢?

因为我每次取min,如果一条边被选过一次了,dp值\(\geq\)只选一次这条边的值

每次都对整棵树dp,复杂度是\(O(nm)\)

这部分代码:(可能是我写的比较丑?好像有人可以50,我只有30)

void dfs1(int x,int fa){
	for (int i = ed1.head[x];i;i = ed1.nxt[i]){
		int to = ed1.to[i];
		if (to == fa) continue;
		minn[to] = min(minn[x],ed1.w[i]);
		dfs1(to,x);
	}
}
void dfs2(int x,int fa){
	int res = 0;
	for (int i = ed1.head[x];i;i = ed1.nxt[i]){
		int to = ed1.to[i];
		if (to == fa) continue;
		//cout<<"-----"<<x<<" "<<to<<" "<<ed1.w[i]<<endl;
		dfs2(to,x);
		res += dp[to];
	}
	if (vis[x]) dp[x] = minn[x];
	else dp[x] = min(minn[x],res);
}

虚树优化

上述复杂度肯定是不行的,我们发现\(\sum\limits k\)比较小,那么我们从这里入手,来建虚树

虚树的主要思想是:对于一棵树,仅仅保留有用的点,重新构建一棵树

这里有用的点指的是询问点和它们的lca

构建:

首先我们要先对整棵树dfs一遍,求出他们的dfs序,然后对每个节点以dfs序为关键字从小到大排序

同时维护一个栈,表示从根到栈顶元素这条链

  • 如果栈为空,那么显然st[1] = x;

  • 取LCA = lca(x,st[top]),如果LCA\(\neq\)st[top],将lca底下的链边删边连边

  • 如果删完发现LCA不在栈中,将LCA加入栈中,然后再把x加入栈中

void build(int x){
	if (top == 0){st[top = 1] = x;return;}
	int LCA = lca(x,st[top]);
	while (top > 1&&dep[LCA] < dep[st[top-1]]) ed2.add(st[top-1],st[top]),top--;
	if (dep[LCA] < dep[st[top]]) ed2.add(LCA,st[top--]);
	if (top == 0||LCA != st[top]) st[++top] = LCA;
	st[++top] = x;
} 
for (int i = 1;i <= q;i++) build(k[i]);
if (top) while (--top) ed2.add(st[top],st[top+1]);

复杂度

因为每次加入新的节点,最多会产生一个新的LCA,那么点数是\(2\times k\)的,复杂度为\(O(2k)\)

完整代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define int long long
#define ll long long
using namespace std;
int read(){
	int x = 1,a = 0;char ch = getchar();
	while (ch < '0'||ch > '9'){if (ch == '-') x = -1;ch = getchar();}
	while (ch >= '0'&&ch <= '9'){a = a*10+ch-'0';ch = getchar();}
	return x*a;
}
const int maxn = 5e5+7,inf = 1e18+7;
int n,m;
bool vis[maxn];
struct node{
	int to[maxn],nxt[maxn],tot,head[maxn],w[maxn];
	node(){tot = 0;memset(head,0,sizeof(head));}
	void add(int x,int y,int z){
		to[++tot] = y,nxt[tot] = head[x],w[tot] = z,head[x] = tot;
		to[++tot] = x,nxt[tot] = head[y],w[tot] = z,head[y] = tot;
	}
	void add(int x,int y){
		to[++tot] = y,nxt[tot] = head[x],head[x] = tot;
		to[++tot] = x,nxt[tot] = head[y],head[y] = tot;
	}
}ed1,ed2;
int st[maxn],top,minn[maxn],dfn[maxn],f[maxn][30],siz[maxn],dep[maxn],cnt;
void dfs(int x){
	dep[x] = dep[f[x][0]]+1,dfn[x] = ++cnt,siz[x] = 1;
	for (int i = 1;i <= 22;i++) f[x][i] = f[f[x][i-1]][i-1];
	for (int i = ed1.head[x];i;i = ed1.nxt[i]){
		int to = ed1.to[i];
		if (to == f[x][0]) continue;
		f[to][0] = x;
		minn[to] = min(minn[x],ed1.w[i]);
		dfs(to);
		siz[x] += siz[to];
	}
}
int lca(int x,int y){
	if (dep[x] > dep[y]) swap(x,y);
	for (int i = 22;i >= 0;i--){
		if (dep[f[y][i]] >= dep[x]) y = f[y][i];
	}
	if (x == y) return x;
	for (int i = 22;i >= 0;i--){
		if (f[x][i] != f[y][i]) x = f[x][i],y = f[y][i];
	}
	return f[x][0];
}
bool cmp(int x,int y){return dfn[x] < dfn[y];}
void build(int x){
	if (top == 0){st[top = 1] = x;return;}
	int LCA = lca(x,st[top]);
	while (top > 1&&dep[LCA] < dep[st[top-1]]) ed2.add(st[top-1],st[top]),top--;
	if (dep[LCA] < dep[st[top]]) ed2.add(LCA,st[top--]);
	if (top == 0||LCA != st[top]) st[++top] = LCA;
	st[++top] = x;
} 
ll dp[maxn];
void dfs1(int x,int fa){
	int res = 0;
	for (int i = ed2.head[x];i;i = ed2.nxt[i]){
		int to = ed2.to[i];
		if (to == fa) continue;
		dfs1(to,x);
		res += dp[to];
	}
	if (vis[x]) dp[x] = minn[x];
	else dp[x] = min(minn[x],res);
	vis[x] = ed2.head[x] = 0;
}
int k[maxn];
void init(){ed2.tot = top = 0;}
signed main(){
	n = read();
	for (int i = 1;i < n;i++){
		int x = read(),y = read(),z = read();
		ed1.add(x,y,z);
	}
	for (int i = 1;i <= n;i++) minn[i] = inf;
	m = read();
	dfs(1);
	while (m--){
		init();
		int q = read();
		for (int i = 1;i <= q;i++) k[i] = read(),vis[k[i]] = 1;
		k[++q] = 1;
		sort(k+1,k+q+1,cmp);
		for (int i = 1;i <= q;i++) build(k[i]);
		if (top) while (--top) ed2.add(st[top],st[top+1]);
		dfs1(1,0);
		printf("%lld\n",dp[1]);
	}
	return 0;
}
posted @ 2021-05-08 10:12  小又又yyyy  阅读(113)  评论(0编辑  收藏  举报