Loading

严格次小生成树

算法模型

例题链接

给定一个有 \(n\) 个点、\(m\) 条边的无向图 \(G\),试求 \(G\) 的所有生成树中,边权总和 严格 次小的生成树。数据保证存在严格次小生成树,不保证 无向图中没有自环。

算法思路

有一个显然的结论:无向图 \(G\) 中至少有一棵次小生成树,与无向图 \(G\) 的最小生成树只有一条边的差异,具体证明略。

由此,我们可以想到一个最为浅显的思路:先求出 \(G\) 的最小生成树,再从最小生成树中删除一条边,并加入一条不在最小生成树中的边。得到的严格次小边权总和即为答案。时间复杂度为 \(O(nm)\),显然无法通过 \(n \leq 10^5\)\(m \leq 3 \times 10^5\) 的情况。

因此,我们可以尝试转换一下思路:先在原图的最小生成树上加入一条原本不在最小生成树上的边 \((u, v)\),再从原始最小生成树上结点 \(u\) 到结点 \(v\) 的路径中,删除一条边权最大的边。枚举每一条不在最小生成树上的边并进行上述操作,得到的值中最小的就是次小生成树的边权总和。

因为最小生成树会选出权值最小的令图连通的 \(n - 1\) 条边,所以不在最小生成树上的边,边权一定会大于 \(n - 1\) 条边中的一条边,由此再分类讨论可以得出结论:上述方法求出的次小生成树,边权总和一定 不小于 最小生成树。

但是,这样操作得出的结果是 不严格 次小生成树。为什么?因为最小生成树上,\(u\)\(v\) 路径中的最大值可能等于加入的边的边权,也就是说,次小生成树的权值总和可能 等于 最小生成树的权值总和。

因此,我们不仅需要维护最小生成树上的路径最大值,还要维护 路径次大值。当路径最大值等于加入的边权时,就删去边权严格次大的那条边,使得次小生成树的边权总和一定 严格大于 最小生成树。

至于如何维护路径最大值和次大值,此处有多种方法,例如 树链剖分倍增等。综合代码长度和思维难度,此处笔者选用倍增维护,下文代码也是以倍增写法为准。

\(f_{u, i}\)\(u\) 的第 \(2^i\) 辈祖先。显然,\(u\)\(f_{u, i}\) 的路径最大值 \(value_{u, i} = value_{u, i - 1}, value_{f_{u, i - 1}, i - 1}\) 。路径次大值则分类讨论:首先,\(u\) 到第 \(2^i\) 辈祖先的路径次大值 \(val_{u, i} = max(val_{u, i - 1}, val_{f_{u, i - 1}, i - 1})\)。如果 \(value_{u, i - 1} < value_{f_{u, i - 1}, i - 1}\) ,此时因为 \(value_{u, i - 1} > val_{u, i - 1}, value_{u, i - 1} > val_{f_{u, i - 1}, i - 1}\),所以路径次大值 \(val_{u, i} = value_{u, i - 1}\)。反之同理,若 \(value_{u, i - 1} > value_{f_{u, i - 1}, i - 1}\),则 \(val_{u, i} = value_{f_{u, i - 1}, u - 1}\)

最后,提几个容易写错的点:

  1. 使用两个 不同的 数组来存边,其中一个用于最小生成树,另一个用于倍增的 \(dfs\) ,原因显然。

  2. long longinf 必须开大,笔者选用 1e16 + 5,一定不可以采用 0x3f3f3f3f2147483647

  3. 读入的边存在最小生成树用的数组中,另一个存边数组 只存储在最小生成树上的边,原因显然。

参考代码

#include <cstdio>
#include <algorithm>
using namespace std;

const int maxn = 1e5 + 5;
const int maxm = 6e5 + 5;
const long long inf = 1e16 + 5;

struct node
{
	int u, v;
	long long w;
} edge[maxm];

struct Edge
{
	int to, nxt;
	long long w;
} g[maxm];

int n, m, cnt;
int head[maxn], fa[maxn], dep[maxn], f[maxn][20];
long long value[maxn][20], val[maxn][20];
bool vis[maxm];

bool cmp(node a, node b)
{
	return a.w < b.w;
}

void add_edge(int u, int v, long long w)
{
	cnt++;
	g[cnt].to = v;
	g[cnt].w = w;
	g[cnt].nxt = head[u];
	head[u] = cnt;
}

void init()
{
	for (int i = 1; i <= n; i++)
		fa[i] = i;
}

int get(int x)
{
	if (fa[x] == x)
		return x;
	return fa[x] = get(fa[x]);
}

void merge(int x, int y)
{
	x = get(x);
	y = get(y);
	if (x != y)
		fa[y] = x;
}

long long kruskal()
{
	long long sum = 0;
	for (int i = 1; i <= m; i++)
	{
		int fu = get(edge[i].u);
		int fv = get(edge[i].v);
		if (fu != fv)
		{
			vis[i] = true;
			merge(fu, fv);
			sum += edge[i].w;
			add_edge(edge[i].u, edge[i].v, edge[i].w);
			add_edge(edge[i].v, edge[i].u, edge[i].w);
		}
	}
	return sum;
}

void dfs(int u)
{
	dep[u] = dep[f[u][0]] + 1;
	for (int i = head[u]; i; i = g[i].nxt)
	{
		if (g[i].to != f[u][0])
		{
			f[g[i].to][0] = u;
			value[g[i].to][0] = g[i].w;
			val[g[i].to][0] = -inf;
			dfs(g[i].to);
		}
	}
}

int lca(int x, int y)
{
	if (dep[x] < dep[y])
		swap(x, y);
	for (int i = 18; i >= 0; i--)
		if (dep[f[x][i]] >= dep[y])
			x = f[x][i];
	if (x == y)
		return x;
	for (int i = 18; i >= 0; i--)
	{
		if (f[x][i] != f[y][i])
		{
			x = f[x][i];
			y = f[y][i];
		}
	}
	return f[x][0];
}

void calc()
{
	for (int j = 1; j <= 18; j++)
	{
		for (int i = 1; i <= n; i++)
		{
			f[i][j] = f[f[i][j - 1]][j - 1];
			value[i][j] = max(value[i][j - 1], value[f[i][j - 1]][j - 1]);
			val[i][j] = max(val[i][j - 1], val[f[i][j - 1]][j - 1]);
			if (value[i][j - 1] < value[f[i][j - 1]][j - 1])
				val[i][j] = max(val[i][j], value[i][j - 1]);
			else if (value[i][j - 1] > value[f[i][j - 1]][j - 1])
				val[i][j] = max(val[i][j], value[f[i][j - 1]][j - 1]);
		}
	}
}

long long query(int u, int v, int w)
{
	long long ret = -inf;
	for (int i = 18; i >= 0; i--)
	{
		if (dep[f[u][i]] >= dep[v])
		{
			if (w != value[u][i])
				ret = max(ret, value[u][i]);
			else
				ret = max(ret, val[u][i]);
			u = f[u][i];
		}
	}
	return ret;
}

int main()
{
	int u, v;
	long long w;
	scanf("%d%d", &n, &m);
	init();
	for (int i = 1; i <= m; i++)
		scanf("%d%d%lld", &edge[i].u, &edge[i].v, &edge[i].w);
	sort(edge + 1, edge + m + 1, cmp);
	long long sum = kruskal(), ans = inf, q;
	dfs(1);
	calc();
	for (int i = 1; i <= m; i++)
	{
		if (!vis[i])
		{
			int l = lca(edge[i].u, edge[i].v);
			q = max(query(edge[i].u, l, edge[i].w), query(edge[i].v, l, edge[i].w));
			ans = min(ans, sum - q + edge[i].w);
		}
	}
	printf("%lld\n", ans);
	return 0;
}
posted @ 2021-07-24 23:25  kymru  阅读(62)  评论(0编辑  收藏  举报