最小斯坦纳树
最小斯坦纳树
斯坦纳树问题是组合优化问题,与 最小生成树相似 ,是最短网络的一种。最小生成树是在给定的点集和边中寻求最短网络使所有点连通。而最小斯坦纳树允许在给定点外增加额外的点,使生成的最短网络开销最小。
我们所熟知的最小生成树问题其实就是一种特殊的最小斯坦纳树问题,相当于在所有点之间找到一个最短网络。
实现
题目描述:
给定一个包含 \(n\) 个结点和 \(m\) 条带权边的无向连通图 \(G=(V,E)\)。
再给定包含 \(k\) 个结点的点集 \(S\),选出 \(G\) 的子图 \(G'=(V',E')\),使得:
- \(S\subseteq V'\);
- \(G'\) 为连通图
- \(E'\)中所有边的权值和最小
求出 \(E'\) 中所有边的权值和
我们通过子集dp求解最小斯坦纳树
一种错误的状态是直接用 \(dp_s\) 表示当前把 \(s\) 中的点加入了最小斯坦纳树,因为转移的时候可能会出现集合不相交的情况,所以再加一维中转点变成 \(dp_{x,s}\),转移有 \(dp_{x,s} = min\{dp_{x,t},dp_{x,s\ xor\ t}\}\),这样可以实现不同最小斯坦纳树集合间的拼接。
但是拼接后不同的中转点可能有不同的贡献,考虑中转点之间的转移,枚举两个中转点有 \(dp_{x,s} = min_{y\to x}\{dp_{y,s} + w(y,x)\}\),但是这样有后效性,用 SPFA 转移可以解决。
代码实现
#include <bits/stdc++.h>
#define ll long long
#define mod 1e9+7
const int MAXN = 101010;
const int inf = 2e9;
using namespace std;
struct edge{
int u, v, w, nxt;
} e[MAXN];
struct node{
int p, w;
bool operator< ( const node x ) const{
return w > x.w;
}
};
int head[MAXN], cnt = 1;
int f[200][2048];
int dis[200], vis[200];
int n, m, k;
priority_queue < node > q;
inline int read( ){
int x = 0 ; short w = 0 ; char ch = 0;
while( !isdigit(ch) ) { w|=ch=='-';ch=getchar();}
while( isdigit(ch) ) {x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
return w ? -x : x;
}
void add( int u, int v, int w ){
e[++cnt] = (edge){ u, v, w, head[u] };
head[u] = cnt;
}
void dij( int s ){
memset( vis, 0, sizeof( vis ) );
for( int i = 1; i <= n; i++ ){
dis[i] = f[i][s];
if( dis[i] < inf ) q.push((node){i, dis[i]});
}
while( !q.empty( ) ){
int x = q.top( ).p;
q.pop( );
if( vis[x] ) continue;
vis[x] = 1;
for( int i = head[x]; i; i = e[i].nxt ){
int y = e[i].v;
if( dis[x] + e[i].w < dis[y] ){
dis[y] = dis[x] + e[i].w;
q.push((node){y, dis[y]});
}
}
}
for( int i = 1; i <= n; i++ )
f[i][s] = dis[i];
}
int main( ){
memset( f, 0x3f, sizeof( f ) );
n = read( ); m = read( ); k = read( );
for( int i = 1; i <= m; i++ ){
int x = read( ), y = read( ), z = read( );
add( x, y, z );
add( y, x, z );
}
for( int i = 1; i <= k; i++ ){
int x = read( );
f[x][1 << (i - 1)] = 0;
}
for( int i = 1; i <= n; i++ ) f[i][0] = 0;
for( int s = 0; s < ( 1 << k ); s++ ){
for( int i = 1; i <= n; i++ )
for( int t = s & ( s - 1 ); t; t = s & ( t - 1 ) )
f[i][s] = min( f[i][s], f[i][t] + f[i][t ^ s] );
dij( s );
}
int ans = inf;
for( int i = 1; i <= n; i++ )
ans = min( ans, f[i][(1 << k) - 1] );
cout << ans;
return 0;
}