AT3611 Tree MST

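题面

给定一棵 \(n\) 个点、边带权的树,点 \(i\) 有点权 \(W_i\)。在这 \(n\) 个点构成的完全图中,点 \(x\) 与点 \(y\) 之间的边权为 \(W_x + W_y + \mathrm{dis}(x, y)\),其中 \(\mathrm{dis}(x, y)\) 表示两点在树上的距离。求这张完全图的最小生成树的边权和。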

题解

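考虑最小化 \(W_x + W_y + \mathrm{dis}(x, y)\):直接建出完全图会有 \(\mathrm{O}(n^2)\) 条边,无法承受,因此要设法只保留 \(\mathrm{O}(n\log_2 n)\) 条可能出现在最小生成树中的候选边。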

这里需要对一种不太常见的最小生成树算法——Borůvka 算法——有深刻的理解:它的每一轮都让每个连通块连出一条最便宜的出边。

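考虑该算法的执行过程,我们可以进行点分治。

**候选边的构造。** 对当前分治块的重心 \(c\),令 \(f(v) = W_v + \mathrm{dis}(v, c)\),找出块内使 \(f(v)\) 最小的点 \(p\)(即"离分治重心最近"的点)。然后让块内每个点 \(x\) 都向 \(p\) 连一条权值为 \(f(x) + f(p)\) 的边。

**正确性。**
- 对于任意一对树上路径经过 \(c\) 的点 \((x, y)\),其真实边权恰为 \(f(x) + f(y)\)。
- 边 \((x, p)\) 与 \((y, p)\) 的权值都不超过 \(f(x) + f(y)\),所以在环 \(x - p - y\) 中边 \((x, y)\) 是最大边,可以删去。
- 由三角不等式,\(f(x) + f(p) \ge W_x + W_p + \mathrm{dis}(x, p)\),所以连出的边不会比真实边更优。

**复杂度。** 每层连出的边数是 \(\mathrm{O}(\)块大小\()\) 的,点分治共 \(\mathrm{O}(\log_2 n)\) 层,所以总边数在 \(\mathrm{O}(n\log_2 n)\) 级别。最后对这些边跑 Kruskal 求出最小生成树即可;由于需要对边排序,总复杂度为 \(\mathrm{O}(n\log_2^2 n)\)。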

代码

#include<cstdio>
#include<cstring>
#include<cctype>
#include<climits>
#include<algorithm>
#define RG register

// Reads one (optionally negative) decimal integer from stdin via getchar().
// Characters that are neither a digit nor '-' before the number are skipped.
// Returns 0 if end of input is reached before any digit is seen.
// ch is an int (not char) so that EOF is recognised reliably and bytes >= 0x80
// are never passed to isdigit() as negative values, which is undefined.
inline int read()
{
	int data = 0, w = 1; int ch = getchar();
	// Stop at EOF too: EOF is neither '-' nor a digit, so without this check
	// the loop would never terminate on truncated input.
	while(ch != EOF && ch != '-' && (!isdigit(ch))) ch = getchar();
	if(ch == '-') w = -1, ch = getchar();
	// ch ^ 48 maps '0'..'9' to 0..9.
	while(isdigit(ch)) data = data * 10 + (ch ^ 48), ch = getchar();
	return data * w;
}

// Maximum number of vertices, with slack for 1-based indexing.
const int maxn(2e5 + 10);
// Adjacency lists stored as linked edge records ("forward star"). Each
// undirected tree edge is inserted twice, hence maxn << 1 slots; index 0 is
// never used and acts as the end-of-list marker.
struct edge { int next, to, dis; } e[maxn << 1];
// head[v]: index of the most recently added edge leaving v (0 = none).
// e_num: number of edge records stored so far.
int head[maxn], e_num;
// Prepends the directed edge from -> to with length dis to from's list.
inline void add_edge(int from, int to, int dis)
{
	// Standard aggregate initialisation; the previous "(edge) {...}" form was a
	// GNU compound-literal extension, not valid ISO C++.
	e[++e_num] = edge{head[from], to, dis};
	head[from] = e_num;
}

// Candidate edge of the complete graph: endpoints x, y and weight w.
// Every vertex receives one candidate edge per centroid level it belongs to
// (roughly log2(n) + 1 levels for a balanced decomposition), so maxn * 50 slots
// leave generous headroom.
struct node { int x, y; long long w; } a[maxn * 50];
// Orders candidate edges by increasing weight for Kruskal; returns bool as a
// sort comparator should.
inline bool cmp(const node &lhs, const node &rhs) { return lhs.w < rhs.w; }
int n, m, W[maxn], SIZE, size[maxn], root, _max, vis[maxn];
void getRoot(int x, int fa)
{
	int max = 0; size[x] = 1;
	for(RG int i = head[x]; i; i = e[i].next)
	{
		int to = e[i].to; if(vis[to] || to == fa) continue;
		getRoot(to, x); size[x] += size[to]; max = std::max(max, size[to]);
	}
	max = std::max(max, SIZE - size[x]);
	if(max < _max) _max = max, root = x;
}

int pos; long long val;
void dfs(int x, int fa, long long dep)
{
	if(dep + W[x] < val) val = dep + W[x], pos = x;
	for(RG int i = head[x]; i; i = e[i].next)
	{
		int to = e[i].to; if(vis[to] || to == fa) continue;
		dfs(to, x, dep + e[i].dis);
	}
}

void link(int x, int fa, long long dep)
{
	a[++m] = (node) {x, pos, val + W[x] + dep};
	for(RG int i = head[x]; i; i = e[i].next)
	{
		int to = e[i].to; if(vis[to] || to == fa) continue;
		link(to, x, dep + e[i].dis);
	}
}

void solve(int x)
{
	vis[x] = 1; val = LLONG_MAX >> 1; pos = 0;
	dfs(x, 0, 0); link(x, 0, 0);
	for(RG int i = head[x]; i; i = e[i].next)
	{
		int to = e[i].to; if(vis[to]) continue;
		SIZE = _max = size[to]; getRoot(to, x);
		solve(root);
	}
}

// ans: total weight of the chosen spanning tree.
// fa: disjoint-set parent pointers (fa[v] == v marks a representative).
long long ans; int fa[maxn];
// Returns the representative of x's set. Every vertex on the path from x is
// re-pointed directly at it (full path compression), done in two passes
// instead of by recursion.
int find(int x)
{
	int rep = x;
	while(fa[rep] != rep) rep = fa[rep];
	while(fa[x] != rep)
	{
		int up = fa[x];
		fa[x] = rep;
		x = up;
	}
	return rep;
}

int main()
{
	SIZE = _max = n = read();
	for(RG int i = 1; i <= n; i++) W[i] = read();
	for(RG int i = 1, x, y, z; i < n; i++)
		x = read(), y = read(), z = read(),
		add_edge(x, y, z), add_edge(y, x, z);
	getRoot(1, 0); solve(root);
	std::sort(a + 1, a + m + 1, cmp);
	for(RG int i = 1; i <= n; i++) fa[i] = i;
	for(RG int i = 1; i <= m; i++)
	{
		if(find(a[i].x) == find(a[i].y)) continue;
		fa[find(a[i].x)] = find(a[i].y); ans += a[i].w;
	}
	printf("%lld\n", ans);
	return 0;
}
posted @ 2019-03-01 17:27  xgzc  阅读(186)  评论(5)  编辑  收藏  举报