严格次小生成树
算法模型
给定一个有 \(n\) 个点、\(m\) 条边的无向图 \(G\),试求 \(G\) 的所有生成树中,边权总和 严格 次小的生成树。数据保证存在严格次小生成树,不保证 无向图中没有自环。
算法思路
有一个显然的结论:无向图 \(G\) 中至少有一棵次小生成树,与无向图 \(G\) 的最小生成树只有一条边的差异,具体证明略。
由此,我们可以想到一个最为浅显的思路:先求出 \(G\) 的最小生成树,再从最小生成树中删除一条边,并加入一条不在最小生成树中的边。得到的严格次小边权总和即为答案。时间复杂度为 \(O(nm)\),显然无法通过 \(n \leq 10^5\),\(m \leq 3 \times 10^5\) 的情况。
因此,我们可以尝试转换一下思路:先在原图的最小生成树上加入一条原本不在最小生成树上的边 \((u, v)\),再从原始最小生成树上结点 \(u\) 到结点 \(v\) 的路径中,删除一条边权最大的边。枚举每一条不在最小生成树上的边并进行上述操作,得到的值中最小的就是次小生成树的边权总和。
因为最小生成树会选出权值最小的令图连通的 \(n - 1\) 条边,所以不在最小生成树上的边,边权一定会大于 \(n - 1\) 条边中的一条边,由此再分类讨论可以得出结论:上述方法求出的次小生成树,边权总和一定 不小于 最小生成树。
但是,这样操作得出的结果是 不严格 次小生成树。为什么?因为最小生成树上,\(u\) 到 \(v\) 路径中的最大值可能等于加入的边的边权,也就是说,次小生成树的权值总和可能 等于 最小生成树的权值总和。
因此,我们不仅需要维护最小生成树上的路径最大值,还要维护 路径次大值。当路径最大值等于加入的边权时,就删去边权严格次大的那条边,使得次小生成树的边权总和一定 严格大于 最小生成树。
至于如何维护路径最大值和次大值,此处有多种方法,例如 树链剖分 、倍增等。综合代码长度和思维难度,此处笔者选用倍增维护,下文代码也是以倍增写法为准。
设 \(f_{u, i}\) 为 \(u\) 的第 \(2^i\) 辈祖先。显然,\(u\) 到 \(f_{u, i}\) 的路径最大值 \(value_{u, i} = value_{u, i - 1}, value_{f_{u, i - 1}, i - 1}\) 。路径次大值则分类讨论:首先,\(u\) 到第 \(2^i\) 辈祖先的路径次大值 \(val_{u, i} = max(val_{u, i - 1}, val_{f_{u, i - 1}, i - 1})\)。如果 \(value_{u, i - 1} < value_{f_{u, i - 1}, i - 1}\) ,此时因为 \(value_{u, i - 1} > val_{u, i - 1}, value_{u, i - 1} > val_{f_{u, i - 1}, i - 1}\),所以路径次大值 \(val_{u, i} = value_{u, i - 1}\)。反之同理,若 \(value_{u, i - 1} > value_{f_{u, i - 1}, i - 1}\),则 \(val_{u, i} = value_{f_{u, i - 1}, u - 1}\)。
最后,提几个容易写错的点:
-
使用两个 不同的 数组来存边,其中一个用于最小生成树,另一个用于倍增的 \(dfs\) ,原因显然。
-
开
long long
,inf
必须开大,笔者选用1e16 + 5
,一定不可以采用0x3f3f3f3f
或2147483647
。 -
读入的边存在最小生成树用的数组中,另一个存边数组 只存储在最小生成树上的边,原因显然。
参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 5;
const int maxm = 6e5 + 5;
const long long inf = 1e16 + 5;
struct node
{
int u, v;
long long w;
} edge[maxm];
struct Edge
{
int to, nxt;
long long w;
} g[maxm];
int n, m, cnt;
int head[maxn], fa[maxn], dep[maxn], f[maxn][20];
long long value[maxn][20], val[maxn][20];
bool vis[maxm];
bool cmp(node a, node b)
{
return a.w < b.w;
}
void add_edge(int u, int v, long long w)
{
cnt++;
g[cnt].to = v;
g[cnt].w = w;
g[cnt].nxt = head[u];
head[u] = cnt;
}
void init()
{
for (int i = 1; i <= n; i++)
fa[i] = i;
}
int get(int x)
{
if (fa[x] == x)
return x;
return fa[x] = get(fa[x]);
}
void merge(int x, int y)
{
x = get(x);
y = get(y);
if (x != y)
fa[y] = x;
}
long long kruskal()
{
long long sum = 0;
for (int i = 1; i <= m; i++)
{
int fu = get(edge[i].u);
int fv = get(edge[i].v);
if (fu != fv)
{
vis[i] = true;
merge(fu, fv);
sum += edge[i].w;
add_edge(edge[i].u, edge[i].v, edge[i].w);
add_edge(edge[i].v, edge[i].u, edge[i].w);
}
}
return sum;
}
void dfs(int u)
{
dep[u] = dep[f[u][0]] + 1;
for (int i = head[u]; i; i = g[i].nxt)
{
if (g[i].to != f[u][0])
{
f[g[i].to][0] = u;
value[g[i].to][0] = g[i].w;
val[g[i].to][0] = -inf;
dfs(g[i].to);
}
}
}
int lca(int x, int y)
{
if (dep[x] < dep[y])
swap(x, y);
for (int i = 18; i >= 0; i--)
if (dep[f[x][i]] >= dep[y])
x = f[x][i];
if (x == y)
return x;
for (int i = 18; i >= 0; i--)
{
if (f[x][i] != f[y][i])
{
x = f[x][i];
y = f[y][i];
}
}
return f[x][0];
}
void calc()
{
for (int j = 1; j <= 18; j++)
{
for (int i = 1; i <= n; i++)
{
f[i][j] = f[f[i][j - 1]][j - 1];
value[i][j] = max(value[i][j - 1], value[f[i][j - 1]][j - 1]);
val[i][j] = max(val[i][j - 1], val[f[i][j - 1]][j - 1]);
if (value[i][j - 1] < value[f[i][j - 1]][j - 1])
val[i][j] = max(val[i][j], value[i][j - 1]);
else if (value[i][j - 1] > value[f[i][j - 1]][j - 1])
val[i][j] = max(val[i][j], value[f[i][j - 1]][j - 1]);
}
}
}
long long query(int u, int v, int w)
{
long long ret = -inf;
for (int i = 18; i >= 0; i--)
{
if (dep[f[u][i]] >= dep[v])
{
if (w != value[u][i])
ret = max(ret, value[u][i]);
else
ret = max(ret, val[u][i]);
u = f[u][i];
}
}
return ret;
}
int main()
{
int u, v;
long long w;
scanf("%d%d", &n, &m);
init();
for (int i = 1; i <= m; i++)
scanf("%d%d%lld", &edge[i].u, &edge[i].v, &edge[i].w);
sort(edge + 1, edge + m + 1, cmp);
long long sum = kruskal(), ans = inf, q;
dfs(1);
calc();
for (int i = 1; i <= m; i++)
{
if (!vis[i])
{
int l = lca(edge[i].u, edge[i].v);
q = max(query(edge[i].u, l, edge[i].w), query(edge[i].v, l, edge[i].w));
ans = min(ans, sum - q + edge[i].w);
}
}
printf("%lld\n", ans);
return 0;
}