CODE FESTIVAL 2017 Final J Tree MST
求完全图的最小生成树,立刻想到 Boruvka。
于是剩下的任务是,对于每个点 \(y\),找到当前和它不在同一连通块的点 \(y\) 的 \(F(x, y) = w_y + dis_{x, y}\) 的最小值。
如果没有 \(x, y\) 所在连通块不同的限制,可以很轻易地换根 dp 完成。先自下而上求出 \(y\) 在子树内的 \(F(x, y)\) 最小值,再自上而下求出 \(y\) 在子树外 \(F(x, y)\) 最小值。
加上了这个限制,我们除了求每个 \(x\) 的 \(F(x, y)\) 最小值和它对应的 \(y\),还要求次小值和它对应的 \(y\)。需要注意我们强制规定最小值和次小值对应的 \(y\) 当前所在连通块不同。这样如果 \(x\) 跟最小值的 \(y\) 在同一连通块,就可以让次小值递补。
时间复杂度 \(O(n \log n)\)。
code
// Problem: J - Tree MST
// Contest: AtCoder - CODE FESTIVAL 2017 Final
// URL: https://atcoder.jp/contests/cf17-final/tasks/cf17_final_j
// Memory Limit: 256 MB
// Time Limit: 5000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 200100;
ll n, a[maxn], head[maxn], len, fa[maxn];
pii f[maxn][2], g[maxn][2], b[maxn];
struct edge {
int to, dis, next;
} edges[maxn << 1];
inline void add_edge(int u, int v, int d) {
edges[++len].to = v;
edges[len].dis = d;
edges[len].next = head[u];
head[u] = len;
}
int find(int x) {
return fa[x] == x ? x : fa[x] = find(fa[x]);
}
inline bool merge(int x, int y) {
x = find(x);
y = find(y);
if (x != y) {
fa[x] = y;
return 1;
} else {
return 0;
}
}
inline void upd(pii a, pii &x, pii &y) {
if (a < x) {
if (find(a.scd) != find(x.scd)) {
y = x;
}
x = a;
} else if (a < y) {
if (find(a.scd) != find(x.scd)) {
y = a;
}
}
}
void dfs(int u, int fa) {
f[u][0] = make_pair(a[u], u);
f[u][1] = make_pair(1e18, -1);
for (int i = head[u]; i; i = edges[i].next) {
int v = edges[i].to, d = edges[i].dis;
if (v == fa) {
continue;
}
dfs(v, u);
pii p1 = f[v][0], p2 = f[v][1];
p1.fst += d;
p2.fst += d;
upd(p1, f[u][0], f[u][1]);
upd(p2, f[u][0], f[u][1]);
}
}
void dfs2(int u, int fa, pii p1, pii p2) {
g[u][0] = f[u][0];
g[u][1] = f[u][1];
upd(p1, g[u][0], g[u][1]);
upd(p2, g[u][0], g[u][1]);
vector<int> son, dis;
for (int i = head[u]; i; i = edges[i].next) {
int v = edges[i].to, d = edges[i].dis;
if (v == fa) {
continue;
}
son.pb(v);
dis.pb(d);
}
if (son.empty()) {
return;
}
int len = (int)son.size();
vector< vector<pii> > pre(len, vector<pii>(2)), suf(len, vector<pii>(2));
pre[0][0] = f[son[0]][0];
pre[0][0].fst += dis[0];
pre[0][1] = f[son[0]][1];
pre[0][1].fst += dis[0];
for (int i = 1; i < len; ++i) {
pre[i][0] = pre[i - 1][0];
pre[i][1] = pre[i - 1][1];
int v = son[i], d = dis[i];
pii t = f[v][0];
t.fst += d;
upd(t, pre[i][0], pre[i][1]);
t = f[v][1];
t.fst += d;
upd(t, pre[i][0], pre[i][1]);
}
suf[len - 1][0] = f[son[len - 1]][0];
suf[len - 1][0].fst += dis[len - 1];
suf[len - 1][1] = f[son[len - 1]][1];
suf[len - 1][1].fst += dis[len - 1];
for (int i = len - 2; ~i; --i) {
suf[i][0] = suf[i + 1][0];
suf[i][1] = suf[i + 1][1];
int v = son[i], d = dis[i];
pii t = f[v][0];
t.fst += d;
upd(t, suf[i][0], suf[i][1]);
t = f[v][1];
t.fst += d;
upd(t, suf[i][0], suf[i][1]);
}
for (int i = 0; i < len; ++i) {
int v = son[i], d = dis[i];
pii q1 = p1, q2 = p2, t = make_pair(d + a[u], u);
q1.fst += d;
q2.fst += d;
upd(t, q1, q2);
if (i) {
t = pre[i - 1][0];
t.fst += d;
upd(t, q1, q2);
t = pre[i - 1][1];
t.fst += d;
upd(t, q1, q2);
}
if (i + 1 < len) {
t = suf[i + 1][0];
t.fst += d;
upd(t, q1, q2);
t = suf[i + 1][1];
t.fst += d;
upd(t, q1, q2);
}
dfs2(v, u, q1, q2);
}
}
void solve() {
scanf("%lld", &n);
for (int i = 1; i <= n; ++i) {
scanf("%lld", &a[i]);
fa[i] = i;
}
for (int i = 1, u, v, d; i < n; ++i) {
scanf("%d%d%d", &u, &v, &d);
add_edge(u, v, d);
add_edge(v, u, d);
}
ll ans = 0;
while (1) {
int cnt = 0;
for (int i = 1; i <= n; ++i) {
cnt += (fa[i] == i);
}
if (cnt == 1) {
break;
}
dfs(1, -1);
dfs2(1, -1, make_pair(1e18, -1), make_pair(1e18, -1));
for (int i = 1; i <= n; ++i) {
b[i] = make_pair(1e18, -1);
}
for (int i = 1; i <= n; ++i) {
pii x = g[i][0], y = g[i][1];
x.fst += a[i];
y.fst += a[i];
if (find(i) == find(x.scd)) {
b[find(i)] = min(b[find(i)], y);
} else {
b[find(i)] = min(b[find(i)], x);
}
}
for (int i = 1; i <= n; ++i) {
if (fa[i] == i && merge(i, b[i].scd)) {
ans += b[i].fst;
}
}
}
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}