AT3611 Tree MST
题面
题解
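题意:给出一棵带边权的树,点 \(i\) 的点权为 \(W_i\);在所有点构成的完全图中,\(x, y\) 之间的边权为 \(W_x + W_y + dis(x, y)\)(\(dis\) 为树上距离),求该完全图的最小生成树。难点在于如何处理 \(dis(x, y)\) 这一项。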
这里需要对一种不太常见的最小生成树算法——Borůvka 算法——有较深的理解。
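借鉴该算法的思想,我们可以进行点分治:设当前分治重心为 \(c\),令 \(d_x = W_x + dis(x, c)\),找到当前连通块内 \(d\) 最小的点 \(p\),然后将连通块内的所有点都向 \(p\) 连一条权为 \(d_x + d_p\) 的边。由于 \(dis(x, y) \le dis(x, c) + dis(c, y)\),这样连出的边权不小于真实边权,且当 \(x, y\) 位于 \(c\) 的不同子树(或其中一点就是 \(c\))时恰好相等;而以 \(d_x + d_y\) 为边权的完全图,其最小生成树正是以 \(p\) 为中心的菊花图,所以只保留这些边不会改变答案。每层连出的边数与连通块大小同阶,点分治共 \(\mathrm{O}(\log_2 n)\) 层,所以总边数在 \(\mathrm{O}(n\log_2 n)\) 级别;最后对这些边跑 Kruskal 求出最小生成树即可,排序带来的总复杂度为 \(\mathrm{O}(n\log_2^2 n)\)。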
代码
#include<cstdio>
#include<cstring>
#include<cctype>
#include<climits>
#include<algorithm>
/* RG used to expand to `register`. That storage class was deprecated in
 * C++11 and removed in C++17 (using it is ill-formed there: clang rejects it
 * by default, GCC warns). RG now expands to nothing, so the existing
 * `RG int i` declarations keep compiling under any standard. */
#define RG
/* Reads the next decimal integer from stdin, optionally preceded by '-'.
 * Any other characters before the number are skipped. If EOF is reached
 * before a digit, returns 0 instead of spinning forever. */
inline int read()
{
	int data = 0, w = 1;
	// getchar() returns int; storing it in a char loses EOF (unsigned char)
	// or confuses it with a 0xFF byte (signed char).
	int ch = getchar();
	while(ch != EOF && ch != '-' && !isdigit(ch)) ch = getchar();
	if(ch == '-') w = -1, ch = getchar();
	// ch ^ 48 == ch - '0' for the ASCII digits '0'..'9'.
	while(isdigit(ch)) data = data * 10 + (ch ^ 48), ch = getchar();
	return data * w;
}
const int maxn(2e5 + 10);
/* Adjacency list of the input tree in linked-list form:
 * - head[v] is the index of the last edge added from v.
 * - e[i].next is the edge added from the same vertex before e[i].
 * - Index 0 ends a list.
 * e[] is sized for two directed entries per tree edge. */
struct edge { int next, to, dis; } e[maxn << 1];
int head[maxn], e_num;
/* Appends the directed edge from -> to with length dis to from's list. */
inline void add_edge(int from, int to, int dis)
{
	// Fill the entry member by member. The former `(edge) {...}` form is a
	// GNU compound-literal extension, not standard C++.
	edge &cur = e[++e_num];
	cur.next = head[from];
	cur.to = to;
	cur.dis = dis;
	head[from] = e_num;
}
/* Candidate MST edge between vertices x and y with weight w.
 * link() appends to a[], which is used 1-based up to index m. */
struct node
{
	int x, y;
	long long w;
} a[maxn * 50];
/* Kruskal ordering: ascending weight. */
inline int cmp(const node &lhs, const node &rhs) { return rhs.w > lhs.w; }
/* n: vertex count; m: number of candidate edges stored in a[]; W: vertex weights. */
int n, m, W[maxn];
/* Working state for getRoot(): the current component size, subtree sizes,
 * the best vertex found so far and its score. */
int SIZE, size[maxn], root, _max;
/* vis[v] != 0 once v has been used as a centroid by solve(). */
int vis[maxn];
void getRoot(int x, int fa)
{
int max = 0; size[x] = 1;
for(RG int i = head[x]; i; i = e[i].next)
{
int to = e[i].to; if(vis[to] || to == fa) continue;
getRoot(to, x); size[x] += size[to]; max = std::max(max, size[to]);
}
max = std::max(max, SIZE - size[x]);
if(max < _max) _max = max, root = x;
}
int pos; long long val;
void dfs(int x, int fa, long long dep)
{
if(dep + W[x] < val) val = dep + W[x], pos = x;
for(RG int i = head[x]; i; i = e[i].next)
{
int to = e[i].to; if(vis[to] || to == fa) continue;
dfs(to, x, dep + e[i].dis);
}
}
void link(int x, int fa, long long dep)
{
a[++m] = (node) {x, pos, val + W[x] + dep};
for(RG int i = head[x]; i; i = e[i].next)
{
int to = e[i].to; if(vis[to] || to == fa) continue;
link(to, x, dep + e[i].dis);
}
}
void solve(int x)
{
vis[x] = 1; val = LLONG_MAX >> 1; pos = 0;
dfs(x, 0, 0); link(x, 0, 0);
for(RG int i = head[x]; i; i = e[i].next)
{
int to = e[i].to; if(vis[to]) continue;
SIZE = _max = size[to]; getRoot(to, x);
solve(root);
}
}
/* ans: total MST weight; fa: union-find parent links. */
long long ans;
int fa[maxn];
/* Returns the representative of x's set. Every vertex on the path from x is
 * re-pointed directly at it (full path compression, done iteratively). */
int find(int x)
{
	int rep = x;
	while(fa[rep] != rep) rep = fa[rep];
	while(fa[x] != rep)
	{
		const int up = fa[x];
		fa[x] = rep;
		x = up;
	}
	return rep;
}
/* Reads the input in this order:
 * - n,
 * - then n vertex weights,
 * - then n - 1 tree edges "x y z".
 * Builds the candidate edges by centroid decomposition, runs Kruskal over
 * them and prints the total weight. */
int main()
{
	n = read();
	SIZE = _max = n;
	for(int v = 1; v <= n; ++v) W[v] = read();
	for(int k = 1; k < n; ++k)
	{
		const int x = read();
		const int y = read();
		const int z = read();
		add_edge(x, y, z);
		add_edge(y, x, z);
	}
	getRoot(1, 0);
	solve(root);
	std::sort(a + 1, a + m + 1, cmp);
	for(int v = 1; v <= n; ++v) fa[v] = v;
	for(int k = 1; k <= m; ++k)
	{
		const int ru = find(a[k].x);
		const int rv = find(a[k].y);
		if(ru == rv) continue;
		fa[ru] = rv;
		ans += a[k].w;
	}
	printf("%lld\n", ans);
	return 0;
}