严格次小生成树
思路
我们首先求出最小生成树,可知次小生成树与最小生成树只有一边之差(恶心)。
证明
我们先求一个最小生成树。
然后,我们加入一条没有使用过的边加入,就会形成一个环。
我们设这个环上最大值为\(Val_1\),次大值为\(Val_2\)。
所以我们\(Val_1>Val_2\)。
我们将加入的边替换掉\(Val_1\)时
ans=ans-val1+z;
其中\(z\)为多余边。
但当我们的\(Val_1\)等于加入的边时,我们就要替换\(Val_1\)不是答案,我们就要替换\(Val_2\)。
我们将加入的边替换掉\(Val_2\)时
ans=ans-val2+z;
我们把得出的值求出最小值就可以了。
但我们有一个很严肃的问题:我们这么求最小值和次小值?
我们就使用\(LCA\)来计算路径上的最大值和次大值。
代码
#include <bits/stdc++.h>
using namespace std;
int n,m,fa[100005],head[100005],ver[200005],edge[200005],next[200005],cnt,ans=0x3f3f3f3f;
long long sum;
struct node{
int u,v,w;
bool flag;
node(){ flag=u=v=w=0; }
bool operator <(const node &x) const { return w<x.w; }
}a[300005];
int find(int x) {
if(fa[x]==x) return fa[x];
return fa[x]=find(fa[x]);
}
void add_edge(int u,int v,int w) { ver[++cnt]=v,edge[cnt]=w,next[cnt]=head[u],head[u]=cnt; }
void Kurskal() {
int k=1;
for(int i=1;i<=n;i++) fa[i]=i;
sort(a+1,a+m+1);
for(int i=1;k<n;i++) {
int s1=find(a[i].u),s2=find(a[i].v);
if(s1!=s2) {
fa[s1]=s2;
a[i].flag=1;
add_edge(a[i].u,a[i].v,a[i].w),add_edge(a[i].v,a[i].u,a[i].w);
k++,sum+=a[i].w;
}
}
return ;
}
int dep[100005],f[100005][25],maxx[100005][25],smax[100005][25];
void dfs(int x) {
for(int i=0;f[x][i];i++) {
f[x][i+1]=f[f[x][i]][i];
maxx[x][i+1]=max(maxx[x][i],maxx[f[x][i]][i]);
if(maxx[x][i]==maxx[f[x][i]][i]) smax[x][i+1]=max(smax[x][i],smax[f[x][i]][i]);
else smax[x][i+1]=max(min(maxx[x][i],maxx[f[x][i]][i]),max(smax[x][i],smax[f[x][i]][i]));
}
for(int i=head[x];i;i=next[i]) {
int v=ver[i];
if(v!=f[x][0]) {
dep[v]=dep[x]+1,maxx[v][0]=edge[i],smax[v][0]=-1,f[v][0]=x;
dfs(v);
}
}
return ;
}
int LCA(int a,int b){
if(dep[a]<dep[b]) swap(a,b);
int d=dep[a]-dep[b];
for(int i=0;d;i++,d>>=1)
if(d & 1) a=f[a][i];
if(a==b) return a;
for(int i=20;i>=0;i--)
if(f[a][i]!=f[b][i]) a=f[a][i],b=f[b][i];
return f[a][0];
}
void work(int u,int v,int w) {
int d=dep[u]-dep[v];
int m1=0,m2=0;
for(int i=0;d;i++,d>>=1) {
if(d & 1) {
m2=max(m2,smax[u][i]);
if(maxx[u][i]>m1) {
m2=max(m2,m1);
m1=maxx[u][i];
}
}
}
if(m1==w) ans=min(ans,w-m2);
else ans=min(ans,w-m1);
}
int main() {
scanf("%d %d",&n,&m);
for(int i=1;i<=m;i++) scanf("%d %d %d",&a[i].u,&a[i].v,&a[i].w);
Kurskal();
dfs(1);
for(int i=1;i<=m;i++) {
if(a[i].flag) continue;
int u=a[i].u,v=a[i].v;
int lca=LCA(u,v);
work(u,lca,a[i].w),work(v,lca,a[i].w);
}
printf("%lld",sum+ans);
return 0;
}