最小斯坦纳树
什么鬼名字,不如叫做扩展最小生成树。
定义
规范定义像是专门不让人看懂来提高算法高级度的,这里说一说比较易懂的定义。
先想最小生成树,它问的是使图上所有点联通的最小边权和。那么如果把所有点改为指定点,并允许其他点存在,那么这就是所谓的最小斯坦纳树。
求法
它的求法和最小生成树没有半毛钱关系,反而要用到最短路。
考虑 dp ,根据题目数据范围容易想到状压 dp,设 \(dp_{mask,i}\) 表示以 \(i\) 节点为当前斯坦纳树的根,应选择的点的状态为 \(mask\) 时的最小代价。
初值先全部赋成极大值,对于每个特殊点让每个对应的状态变成 \(0\)。
转移的时候分两种情况,首先考虑不换根的情况,正常状压,方程如下:
\[dp_{mask,i}=\min_{s\in mask}(dp_{s,i}+dp_{mask\oplus x,i})
\]
枚举所有子集转移。子集枚举是小技巧。
考虑换根的情况,方程为:
\[dp_{mask,i}=\min(dp_{mask,j}+v_{i,j})
\]
转移的时候用最短路转移。
Code
点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int mod=998244353;
inline int read()
{
int w=1,s=0;char ch=getchar();
while(!isdigit(ch)){if(ch=='-') w=-1;ch=getchar();}
while(isdigit(ch)){s=s*10+(ch-'0');ch=getchar();}
return w*s;
}
inline ll read_()
{
ll w=1;ll s=0;char ch=getchar();
while(!isdigit(ch)){if(ch=='-') w=-1;ch=getchar();}
while(isdigit(ch)){s=s*10ll+(ch-'0');ch=getchar();}
return w*s;
}
const int maxn=1e5+10;
int n,m,k;
ll ans=1e18;
struct no
{
int y,v;
};
struct dij
{
int y;
ll d;
inline friend bool operator < (dij x,dij y)
{
return x.d>y.d;
}
};
priority_queue<dij> q;
bool vis[maxn];
vector<no> G[maxn];
int c[maxn];
ll dp[35][maxn];
void dijkstra(int s)
{
for(int i=0;i<=n;i++)vis[i]=0;
while(!q.empty())
{
int u=q.top().y;
q.pop();
if(vis[u])continue;
vis[u]=1;
for(auto i : G[u])
{
int y=i.y;ll v=i.v;
if(dp[s][y]>dp[s][u]+v)
{
dp[s][y]=dp[s][u]+v;
q.push({y,dp[s][y]});
}
// cout<<dp[s][y]<<' ';
}
// cout<<endl;
}
}
signed main()
{
cin>>n>>k>>m;
for(int i=0;i<=n;i++)
{
for(int mask=0;mask<=(1<<k);mask++)
{
dp[mask][i]=1e18;
}
}
for(int i=1;i<=k;i++)
{
int x=read();
dp[1<<(i-1)][x]=0;
}
for(int i=1;i<=m;i++)
{
int x=read(),y=read();ll v=read_();
G[x].push_back({y,v});
G[y].push_back({x,v});
}
for(int mask=1;mask<(1<<k);mask++)
{
for(int i=1;i<=n;i++)
{
for(int j=mask&(mask-1);j;j=(j-1)&mask)
{
dp[mask][i]=min(dp[mask][i],dp[j][i]+dp[mask^j][i]);
}
if(dp[mask][i]<dp[0][0])
{
q.push({i,dp[mask][i]});
}
}
dijkstra(mask);
// for(int j=1;j<=n;j++)
// {
// cout<<dp[j][mask]<<' ';;
// }cout<<endl;
}
for(int i=1;i<=n;i++)ans=min(ans,dp[(1<<k)-1][i]);
cout<<ans;
return 0;
}