虚树

消耗战

首先考虑朴素 \(dp\),设 \(f_u\) 表示使 \(u\) 的子树内的所有关键点都不与 \(u\) 连通的最小代价。如果当前在 \(u\)\(j\) 这个儿子是关键点,那么有转移 \(f_u\leftarrow f_u+val_{u\rightarrow j}\);否则有转移 \(f_u\leftarrow f_u+\min(f_j,val_{u\rightarrow j})\)

这样做的时间复杂度是 \(O(nm)\) 的,我们考虑优化。

暴力代码:

#include<bits/stdc++.h>
#define int long long
#define N 250005
#define M 500005
using namespace std;
int n,m,h[N],e[M],w[M],ne[M],idx,x[N],f[N];
bool st[N];
void add(int a,int b,int c){
	e[idx]=b;w[idx]=c;ne[idx]=h[a];h[a]=idx++;
}
void dfs(int u,int fa){
	f[u]=0;
	for(int i=h[u];~i;i=ne[i]){
		int j=e[i];
		if(j==fa)continue;
		dfs(j,u);
		if(st[j])f[u]+=w[i];
		else f[u]+=min(f[j],w[i]);
	}
}
signed main(){
	cin>>n;
	memset(h,-1,sizeof h);
	for(int i=1;i<n;i++){
		int a,b,c;
		cin>>a>>b>>c;
		add(a,b,c);add(b,a,c);
	}
	cin>>m;
	while(m--){
		int k;
		cin>>k;
		for(int i=1;i<=k;i++){
			cin>>x[i];
			st[x[i]]=1;
		}
		dfs(1,0);
		cout<<f[1]<<'\n';
		for(int i=1;i<=k;i++){
			st[x[i]]=0;
		}
	}
	return 0;
}

发现 \(\sum k\) 非常小,那么我们能不能把 \(dp\) 的总复杂度控制在 \(O(\sum k)\) 呢?

我们接下来说明如何建虚树:

首先加入所有关键点,然后把他们按照 \(dfs\) 序排序,接着加入相邻两点的 \(lca\),最后对这些点去重,建树。首先这样所有询问的总点数不超过 \(2\times \sum k\),所以复杂度是对的。

那么,为什么这样做是对的呢?考虑两个 \(dfs\) 序相邻的点 \(x,y\) 他们的 \(lca\)\(fa\)。于是我们只需要连接 \(fa\rightarrow y\)。这样是不重不漏的。分类讨论一下:

  • \(x\)\(y\) 的祖先,则 \(fa\)\(x\),因为 \(x,y\)\(dfs\) 序相邻,所以 \(x\rightarrow y\) 上没有其他关键点。

  • \(x\) 不是 \(y\) 的祖先,那么把 \(fa\) 当作 \(y\) 的祖先,\(fa\rightarrow y\) 同上可得没有关键点。

代码:

#include<bits/stdc++.h>
#define int long long
#define N 250005
#define K 20
#define pii pair<int,int>
#define x first
#define y second
using namespace std;
int n,m,k,h[N],a[N],cnt,mn[N];
int dfn[N],tot,f[N];
int fa[N][K],dep[N];
bool st[N];
vector<pii>g[N];
vector<int>e[N];
void add(int a,int b){
	e[a].push_back(b);
}
void add1(int a,int b,int c){
	g[a].push_back({b,c});
}
void dfs(int u,int pre){
	fa[u][0]=pre;
	for(int i=1;i<K;i++){
		fa[u][i]=fa[fa[u][i-1]][i-1];
	}
	dfn[u]=++tot;
	dep[u]=dep[pre]+1;
	for(auto it:g[u]){
		int j=it.x,val=it.y;
		if(j==pre)continue;
		mn[j]=min(mn[u],val);
		dfs(j,u);
	}
}
int get_lca(int a,int b){
	if(dep[a]<dep[b])swap(a,b);
	for(int i=K-1;~i;i--){
		if(dep[fa[a][i]]>=dep[b]){
			a=fa[a][i];
		}
	} 
	if(a==b)return a;
	for(int i=K-1;~i;i--){
		if(fa[a][i]!=fa[b][i]){
			a=fa[a][i];
			b=fa[b][i];
		}
	}
	return fa[a][0];
}
int dp(int u){
	if(e[u].size()==0)return mn[u];
	int sum=0;
	for(auto j:e[u]){
		sum+=dp(j);
	}
	e[u].clear();
	if(st[u])return mn[u];
	else return min(sum,mn[u]);
}
void build(){
	sort(h+1,h+k+1,[&](int x,int y){
		return dfn[x]<dfn[y];
	});
	for(int i=1;i<k;i++){
		a[++cnt]=h[i];
		a[++cnt]=get_lca(h[i],h[i+1]);
	}
	a[++cnt]=h[k];
	a[++cnt]=1;
	sort(a+1,a+cnt+1,[&](int x,int y){
		return dfn[x]<dfn[y];
	});
	cnt=unique(a+1,a+cnt+1)-a-1;
	for(int i=1;i<cnt;i++){
		int lca=get_lca(a[i],a[i+1]);
		add(lca,a[i+1]);
	}
}
signed main(){
	cin>>n;
	memset(mn,0x3f,sizeof mn);
	for(int i=1;i<n;i++){
		int a,b,c;
		cin>>a>>b>>c;
		add1(a,b,c);add1(b,a,c);
	}
	dfs(1,0);
	cin>>m;
	while(m--){
		cin>>k;
		for(int i=1;i<=k;i++){
			cin>>h[i];
			st[h[i]]=1;
		}
		cnt=0;
		build();
		cout<<dp(1)<<'\n';
		for(int i=1;i<=k;i++){
			st[h[i]]=0;
		}
	}
	return 0;
}

Kingdom and its Cities

仍然先考虑朴素 \(dp\),考虑设 \(f_u\) 表示以 \(u\) 为根的子树的答案,\(g_u\) 表示以 \(u\) 为根的子树中有没有关键点与 \(u\) 连通。于是我们对转移方程分类讨论:

  • \(u\) 为关键点,将他和没断的子孙全部断掉,于是对于 \(u\) 的儿子 \(j\)\(f_u\leftarrow f_u+\sum g_j\)

  • \(u\) 不是关键点,再分讨一下:

    • \(1\) 个以上子孙关键点没断,把自己切掉即可:\(f_u\leftarrow f_u+1\)

    • 只有 \(1\) 个关键点没断,则留到祖先去:\(g_u\leftarrow 1\)

    • 没有关键点没断:\(g_u\leftarrow 0\)

暴力代码:

#include<bits/stdc++.h>
#define int long long
#define N 100005
using namespace std;
int n,m,f[N],g[N],fa[N];
int dfn[N],cnt,h[N];
bool st[N];
vector<int>e[N],e1[N];
void add(int a,int b){
	e[a].push_back(b);
}
void add1(int a,int b){
	e1[a].push_back(b);
}
void dfs(int u,int pre){
	f[u]=0;
	int sum=0;
	for(auto j:e1[u]){
		if(j==pre)continue;
		dfs(j,u);
		sum+=g[j];
		f[u]+=f[j];
	}
	if(g[u]){
		f[u]+=sum;
	}
	else{
		if(sum>1)f[u]++;
		else if(sum==1)g[u]=1;
	}
}
void dfs1(int u,int pre){
	fa[u]=pre;
	dfn[u]=++cnt;
	for(auto j:e1[u]){
		if(j==pre)continue;
		dfs1(j,u);
	}
}
signed main(){
	cin>>n;
	for(int i=1;i<n;i++){
		int a,b;
		cin>>a>>b;
		add1(a,b);add1(b,a);
	}
	dfs1(1,0);
	cin>>m;
	while(m--){
		int k;
		cin>>k;
		for(int i=1;i<=k;i++){
			cin>>h[i];
			g[h[i]]=1;
		}
		bool flag=1;
		for(int i=1;i<=k;i++){
			if(g[fa[h[i]]]){
				flag=0;
				break;
			}
		}
		if(!flag){
			cout<<"-1\n";
			for(int i=1;i<=n;i++){
				g[i]=0;
			}
			continue;
		}
		dfs(1,0);
		cout<<f[1]<<'\n';
		for(int i=1;i<=n;i++){
			g[i]=0;
		}
	}
	return 0;
}

然后发现 \(\sum k_i\le 10^5\),于是考虑使用虚树,直接构建出来跑一个 \(dp\) 即可。

代码:

#include<bits/stdc++.h>
#define int long long
#define N 100005
#define K 20
using namespace std;
int n,m,k,f[N],g[N],fa[N][K],dep[N];
int dfn[N],idx,h[N],a[N],cnt;
bool st[N];
vector<int>e[N],e1[N];
void add(int a,int b){
	e[a].push_back(b);
}
void add1(int a,int b){
	e1[a].push_back(b);
}
int dp(int u){
	int now=0,res=0;
	for(auto j:e[u]){
		res+=dp(j);
		now+=g[j];
	}
	if(st[u]){
		res+=now;
		g[u]=1;
	}
	else if(now>1){
		res++;
		g[u]=0;
	}
	else{
		if(now)g[u]=1;
		else g[u]=0;
	}
	e[u].clear();
	st[u]=0;
	return res;
}
void dfs1(int u,int pre){
	fa[u][0]=pre;
	for(int i=1;i<K;i++){
		fa[u][i]=fa[fa[u][i-1]][i-1];
	}
	dfn[u]=++idx;
	dep[u]=dep[pre]+1;
	for(auto j:e1[u]){
		if(j==pre)continue;
		dfs1(j,u);
	}
}
int get_lca(int a,int b){
	if(dep[a]<dep[b])swap(a,b);
	for(int i=K-1;~i;i--){
		if(dep[fa[a][i]]>=dep[b]){
			a=fa[a][i];
		}
	}
	if(a==b)return a;
	for(int i=K-1;~i;i--){
		if(fa[a][i]!=fa[b][i]){
			a=fa[a][i];
			b=fa[b][i];
		}
	}
	return fa[a][0];
}
void build(){
	sort(h+1,h+k+1,[&](int x,int y){
		return dfn[x]<dfn[y];
	});
	cnt=0;
	for(int i=1;i<k;i++){
		a[++cnt]=h[i];
		a[++cnt]=get_lca(h[i],h[i+1]);
	}
	a[++cnt]=h[k];
	a[++cnt]=1;
	sort(a+1,a+cnt+1,[&](int x,int y){
		return dfn[x]<dfn[y];
	});
	cnt=unique(a+1,a+cnt+1)-a-1;
	for(int i=1;i<cnt;i++){
		int lca=get_lca(a[i],a[i+1]);
		add(lca,a[i+1]);
	}
}
signed main(){
	cin>>n;
	for(int i=1;i<n;i++){
		int a,b;
		cin>>a>>b;
		add1(a,b);add1(b,a);
	}
	dfs1(1,0);
	cin>>m;
	while(m--){
		cin>>k;
		for(int i=1;i<=k;i++){
			cin>>h[i];
			st[h[i]]=1;
		}
		bool flag=1;
		for(int i=1;i<=k;i++){
			if(st[fa[h[i]][0]]){
				flag=0;
				break;
			}
		}
		if(!flag){
			cout<<"-1\n";
			for(int i=1;i<=k;i++){
				st[h[i]]=0;
			}
			continue;
		}
		build();
		cout<<dp(1)<<'\n';
	}
	return 0;
}
posted @ 2024-09-09 11:57  zxh923  阅读(5)  评论(0编辑  收藏  举报