虚树学习笔记

虚树学习笔记

[SDOI2011]消耗战

Link

题意

给一棵\(n\)个点,带边权的树。

\(m\)组询问,每组有\(k_i\)个关键点,你需要切断一些边,使得每个点都到不了根节点,求最小代价。\

\(n <= 2.5 \cdot 10^5, m <= 5 \cdot 10^5,\sum k_i <= 5 \cdot 10^5\)

Solve 1

对于每组询问,做一个\(dp\),设\(f[x]\)表示切断\(x\)和他的子树所需最小代价,转移分两种

  1. \(x\)是关键点,答案为\(x\)到根路径最小值
  2. \(x\)不是关键点,答案为切断所有儿子的值和第\(1\)种取\(min\)

复杂度\(O(nm)\)

\(can \ we \ do \ better?\)

Solve 2

要用到所讲的虚树。

我们发现转移过程中,对于转移有贡献的只有关键点以及他们之间的祖先,于是我们可以简化树的结构。

把关键点按\(dfs\)序排序,相邻两个求出\(lca\)并建边。最后在虚树上做\(dp\),复杂度\(O(n\ log \ n + \sum k_i log \ n)\)

具体实现用一个栈维护一条树链,排序后一次加入点。

设当前加入的点\(u\)

  • 如果\(top <=1\) ,\(stk[++top] = u\)
  • \(l = lca(u,stk[top])\),如果\(l == stk[top]\),那么\(u\)应该接在\(stk[top]\)底下,\(stk[++top] = u\)
  • 否则说明\(u\)已经是一个新的子树,持续弹栈直到\(dfn[stk[top-1]] < dfn[l] <= dfn[stk[top]]\),如果\(l != stk[top]\),把\(stk[top]\)接在\(l\)后面,\(stk[top] = l\),最后\(stk[++top] = u\)
void insert(int u){
	if(top <= 1) return stk[++top] = u,void();
	int l = lca(u,stk[top]);
	if(l == stk[top]) return stk[++top] = u,void();
	while(top > 1 && dfn[l] <= dfn[stk[top-1]]){
		add(stk[top-1],stk[top]); top--;
	}
	if(l != stk[top]) add(l,stk[top]),stk[top] = l;
	stk[++top] = u;
	return ;
}

Code

#include<bits/stdc++.h>
#define int long long
#define N 1000015
#define rep(i,a,n) for (int i=a;i<=n;i++)
#define per(i,a,n) for (int i=n;i>=a;i--)
#define inf 0x3f3f3f3f3f3f3f3f
#define pb push_back
#define mp make_pair
#define pii pair<int,int>
#define fi first
#define se second
#define lowbit(i) ((i)&(-i))
#define VI vector<int>
#define all(x) x.begin(),x.end()
using namespace std;
int n,m,a[N],Min[N],k,dfn[N],clk;
vector<pii> e[N];
VI g[N];
void dfs(int u,int fa){
	dfn[u] = ++clk;
	for(auto I:e[u]){
		int v = I.fi,w = I.se;
		if(v == fa) continue;
		Min[v] = min(Min[u],w);
		dfs(v,u);
	}
}
bool cmp(int u,int v){
	return dfn[u] < dfn[v];
}
namespace LCA{
	int fa[N][24],dep[N];
	void Dfs(int u,int f){
		fa[u][0] = f; dep[u] = dep[f]+1;
		for(auto I:e[u]){
			int v = I.fi;
			if(v == f) continue;
			Dfs(v,u);
		}
	}
	void init(){
		rep(j,1,21){
			rep(i,1,n){
				fa[i][j] = fa[fa[i][j-1]][j-1];
			}
		}
	}
	int lca(int u,int v){
		if(dep[u] < dep[v]) swap(u,v);
		int t = dep[u] - dep[v];
		per(i,0,21){
			if((1<<i)&t) u = fa[u][i];
		}
		if(u == v) return u;
		per(i,0,21){
			if(fa[u][i] != fa[v][i]) u = fa[u][i],v = fa[v][i];
		}
		return fa[u][0];
	}
}
using namespace LCA;
int stk[N],top;
void add(int u,int v){
	//printf("%lld -> %lld\n",u,v);
	g[u].pb(v);
}
void insert(int u){
	if(top <= 1) return stk[++top] = u,void();
	int l = lca(u,stk[top]);
	if(l == stk[top]) return stk[++top] = u,void();
	while(top > 1 && dfn[l] <= dfn[stk[top-1]]){
		add(stk[top-1],stk[top]); top--;
	}
	if(l != stk[top]) add(l,stk[top]),stk[top] = l;
	stk[++top] = u;
	return ;
}
void build(){
	top = 0;
	stk[++top] = 1;
	rep(i,1,k) insert(a[i]);
	while(top > 1) add(stk[top-1],stk[top]),top--;
}
bool gkp[N];
int dp(int u){
	int res = 0;
	if(g[u].size() == 0){
		//printf("u: %lld val: %lld\n",u,Min[u]);
		return Min[u];
	}
	for(auto v:g[u]){
		res += dp(v);
	}
	g[u].clear();
	if(!gkp[u]) return min(res,Min[u]);
	//printf("u: %lld val: %lld\n",u,res);
	return Min[u];
}

signed main(){
	//freopen(".in","r",stdin);
	//freopen(".out","w",stdout);
 	scanf("%lld",&n);
 	memset(Min,0x3f,sizeof Min);
 	rep(i,2,n){
 		int u,v,w; scanf("%lld%lld%lld",&u,&v,&w);
 		e[u].pb(mp(v,w)); e[v].pb(mp(u,w));
 	}
 	dfs(1,0);
 	// rep(i,1,n) printf("%lld ", Min[i]);
 	// printf("\n");
 	Dfs(1,0); init();
 	// rep(i,1,n){
 	// 	rep(j,i+1,n){
 	// 		printf("(i,j): (%lld,%lld) lca: %lld\n",i,j,lca(i,j));
 	// 	}
 	// }
 	scanf("%lld",&m);
 	rep(_,1,m){
 		scanf("%lld",&k); rep(i,1,k) scanf("%lld",&a[i]),gkp[a[i]] = 1;
 		sort(a+1,a+k+1,cmp); //puts("sort finished");
 		build(); //puts("build finished");
 		printf("%lld\n",dp(1));
 		rep(i,1,k) gkp[a[i]] = 0;
 	}
	return 0;
}
posted @ 2021-01-23 12:46  趁着胆子小  阅读(93)  评论(1编辑  收藏  举报
//explotion effect (unabled)