最小斯坦纳树

转自(稍加修改)

最小斯坦纳树,就是在一个无向连通图要花费最小的代价,连通给定的 \(k\) 个关键点(一般 \(k\le 10\)),这是一个组合优化问题。

这个问题可以用状压 DP 来解决,首先容易发现一个结论:

答案一定是树。你猜为啥叫最小斯坦纳树。

证明:如果答案存在环,则删去环上任意一条边,代价变小。

于是我们为这棵树钦定一个树根,设 \(f(S,i)\) 表示以 \(i\) 为根的一棵树,包含集合 \(S\) 中所有点的最小代价(只考虑关键点,即 \(S\) 是关键点集合的子集)。

考虑如何不重不漏地转移。

一棵以 \(i\) 为根的树有两种情况,第一种是 \(i\)\(deg=1\),另一种是 \(deg>1\)

对于 \(deg=1\) 的情况,可以考虑枚举树上与 \(i\) 相邻的点 \(j\),则:

\[f(S,j)+w(i,j)\to f(S,i)\quad(A) \]

对于 \(deg>1\) 的情况,可以划分成几个子树考虑,即:

\[f(T,i)+f(S-T,i)\to f(S,i)\quad(T\subsetneq S\land T\ne \varnothing)\quad(B) \]

这里的转移顺序是有讲究的,这可以理解成一个类似背包的 DP,按 \(S\)(二进制形态)升序枚举即可。

这两种转移具体如何实现呢?对于 \((B)\) 式较为简单,枚举子集即可,时间复杂度为 \(O(3^k n)\)

对于 \((A)\) 式,可以想到最短路的松弛。所以在 \((B)\) 式枚举子集后,在同一个 \(S\)\(n\) 个点跑 dij 即可,这部分时间复杂度为 \(O(2^km\log m)\)

所以总时间复杂度为 \(O(3^kn+2^km\log m)\)

P6192 【模板】最小斯坦纳树

//We'll be counting stars.
//#pragma GCC optimize("Ofast")
#include<bits/stdc++.h>
using namespace std;
#define fir first
#define sec second
#define mkp make_pair
#define pb emplace_back
#define For(i,j,k) for(int i=(j),i##_=(k);i<=i##_;i++)
#define Rof(i,j,k) for(int i=(j),i##_=(k);i>=i##_;i--)
#define ckmx(a,b) a=max(a,b)
#define ckmn(a,b) a=min(a,b)
#define debug(...) cerr<<"#"<<__LINE__<<": "<<__VA_ARGS__<<endl
#define N 101
#define V (1<<10)
#define pi pair<int,int>
const int inf=1e9;
int n,m,k,S,f[V][N];
vector<pi> e[N];
priority_queue<pi> q;
bool vis[N];
void dij(int *dis){
	fill(vis+1,vis+1+n,false);
	int x;
	while(!q.empty()){
		x=q.top().sec;
		q.pop();
		if(vis[x]) continue;
		vis[x]=true;
		for(auto i:e[x]){
			if(dis[i.fir]>dis[x]+i.sec){
				dis[i.fir]=dis[x]+i.sec;
				q.push(mkp(-dis[i.fir],i.fir));
			}
		}
	}
}
signed main(){ios::sync_with_stdio(false),cin.tie(nullptr);
	cin>>n>>m>>k;
	S=1<<k;
	int x,y,z;
	For(i,1,m){
		cin>>x>>y>>z;
		e[x].pb(mkp(y,z));
		e[y].pb(mkp(x,z));
	}
	For(i,1,S-1) fill(f[i]+1,f[i]+1+n,inf);
	For(i,1,k){
		cin>>x;
		f[1<<(i-1)][x]=0;
	}//dont change x after this (x for a representative)
	For(s,1,S-1){
		For(i,1,n){//f(S,i)<-f(T,i)+f(S^T,i) (T|S=S)
			for(int ss=s&(s-1);ss;ss=s&(ss-1))
				ckmn(f[s][i],f[ss][i]+f[s^ss][i]);
			if(f[s][i]<inf) q.push(mkp(-f[s][i],i));
		}
		dij(f[s]);//f(S,i)<-f(S,j)+w(i,j)
	}
	cout<<f[S-1][x]<<endl;
return 0;}

P4294 [WC2008]游览计划

这里是点权,所以我们的 DP 转移要改一下:

\[f(S,j)+a_i\to f(S,i) \]

\[f(T,i)+f(S-T,i)-a_i\to f(S,i)\quad(T\subsetneq S\land T\ne \varnothing) \]

后者 \(-a_i\) 的原因是被算了两遍。

一样做。

最后 DFS 求方案即可。

//We'll be counting stars.
//#pragma GCC optimize("Ofast")
#include<bits/stdc++.h>
using namespace std;
#define fir first
#define sec second
#define mkp make_pair
#define pb emplace_back
#define For(i,j,k) for(int i=(j),i##_=(k);i<=i##_;i++)
#define Rof(i,j,k) for(int i=(j),i##_=(k);i>=i##_;i--)
#define ckmx(a,b) a=max(a,b)
#define ckmn(a,b) a=min(a,b)
#define debug(...) cerr<<"#"<<__LINE__<<": "<<__VA_ARGS__<<endl
#define N 101
#define V (1<<10)
const int inf=1e9;
int n,m,lim,S,k=0,a[N],f[V][N],pre[V][N];
bool ans[N],vis[N],used[V][N];
vector<int> e[N];
priority_queue<pair<int,int> > q;
int num(int x,int y){ return (x-1)*m+y; }
void adde(int x,int y){
	e[x].pb(y);
	e[y].pb(x);
}
void dij(int *dis,int *p){
	fill(vis+1,vis+1+lim,false);
	int x;
	while(!q.empty()){
		x=q.top().sec;
		q.pop();
		if(vis[x]) continue;
		vis[x]=true;
		for(int i:e[x]){
			if(dis[i]>dis[x]+a[i]){
				dis[i]=dis[x]+a[i];
				q.push(mkp(-dis[i],i));
				p[i]=x;
			}
		}
	}
}
void dfs(int x,int y){
	if(used[x][y]) return ;
	used[x][y]=true;
	ans[y]=true;
	if(pre[x][y]){
		int tmp=pre[x][y];
		while(tmp){
			dfs(x,tmp);
			tmp=pre[x][tmp];
		}
	}else{
		for(int s=x&(x-1);s;s=x&(s-1)){
			if(f[x][y]==f[s][y]+f[x^s][y]-a[y]){
				dfs(s,y);
				dfs(x^s,y);
				break;
			}
		}
	}
}
signed main(){ios::sync_with_stdio(false),cin.tie(nullptr);
	cin>>n>>m;
	lim=n*m;
	For(i,1,lim) cin>>a[i];
	For(i,1,lim) if(!a[i]) k++;
	if(!k){
		cout<<0<<endl;
		For(i,1,n){
			For(j,1,m) cout<<"_";
			cout<<endl;
		}
		return 0;
	}
	S=1<<k;
	For(i,1,S-1) fill(f[i]+1,f[i]+lim+1,inf);
	int x=0,rep;
	For(i,1,n) For(j,1,m){
		if(!a[num(i,j)]) f[1<<(x++)][num(i,j)]=0,rep=num(i,j);
		if(i<n) adde(num(i,j),num(i+1,j));
		if(j<m) adde(num(i,j),num(i,j+1));
	}
	For(s,1,S-1){
		For(i,1,lim){
			for(int ss=s&(s-1);ss;ss=s&(ss-1))
				ckmn(f[s][i],f[s^ss][i]+f[ss][i]-a[i]);//double calced a[i] so subtract it
			if(f[s][i]<inf) q.push(mkp(-f[s][i],i));
		}
		dij(f[s],pre[s]);
	}
	cout<<f[S-1][rep]<<endl;
	dfs(S-1,rep);
	For(i,1,n){
		For(j,1,m){
			if(!a[num(i,j)]) cout<<"x";
			else if(ans[num(i,j)]) cout<<"o";
			else cout<<"_";
		}
		cout<<endl;
	}
return 0;}
posted @ 2022-08-17 14:54  ShaoJia  阅读(507)  评论(0编辑  收藏  举报