【题解】AT3611 Tree MST
喝,长大了🎵↑↓↑↓
思路
点分治 + kruskal.
对于完全图 MST 问题有一个结论:
\(\forall G = (V, E)\),若 \(E_1 \cup ... \cup E_m = E\),设 \(E_i\) 经过 kruskal 得到的有效边集为 \(E^{\prime}_i\),则对 \(E^{\prime}_1 \cup ... \cup E^{\prime}_m\) 进行 kruskal 得到的是 \(G\) 的 MST。
又发现题目考虑的是树上路径问题,所以想到点分治。
因为点分治可以考虑到树上的每一条路径,它们取并显然可以覆盖所有边,所以可以直接在点分的时候求当前子树的 MST。
因为点分的时候钦定子树重心为根,所以树上路径可以换一种形式写。
令 \(dep_u\) 表示从当前分治的根结点到点 \(u\) 的距离,那么在 \((u, v)\) 之间连边的代价是 \(w_u + w_v + dep_u + dep_v\)
可以把贡献拆开,\(w_u + dep_u\) 赋给 \(u\),\(v\) 同理。
那么只需要在子树中找出 \(w_u + dep_u\) 最小的点 \(u\),然后子树内的其他所有点都向点 \(u\) 连边即可。
那么只需要进行一遍点分就可以找出所有有用的点,直接 kruskal 一次就行。
点分的复杂度是 \(O(n \log n)\),所以至多有 \(O(n \log n)\) 条边,总复杂度 \(O(n \log^2 n)\)
代码
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long ll;
const int maxn = 2e5 + 5;
const int maxm = 5e6 + 5;
const int inf = 0x3f3f3f3f;
struct item
{
int u, v;
ll w;
bool operator < (const item& rhs) const { return (w < rhs.w); }
} edge[maxm];
struct node
{
int p;
ll w;
bool operator < (const node& rhs) const { return (w < rhs.w); }
} nd[maxn];
int n, ecnt, len;
int rt, max_sz, tot_sz;
int w[maxn], fa[maxn], sz[maxn];
bool vis[maxn];
ll dep[maxn];
vector<int> g[maxn], l[maxn];
inline int read()
{
int res = 0;
char ch = getchar();
while ((ch < '0') || (ch > '9')) ch = getchar();
while ((ch >= '0') && (ch <= '9'))
{
res = res * 10 + ch - '0';
ch = getchar();
}
return res;
}
inline void get_sz(int u, int fa)
{
sz[u] = 1;
for (int v : g[u])
{
if ((v == fa) || vis[v]) continue;
get_sz(v, u);
sz[u] += sz[v];
}
}
inline void get_rt(int u, int fa)
{
int res = 0;
for (int v : g[u])
{
if ((v == fa) || vis[v]) continue;
get_rt(v, u);
res = max(res, sz[v]);
}
res = max(res, tot_sz - sz[u]);
if (res < max_sz) rt = u, max_sz = res;
}
inline void get_dep(int u, int fa)
{
nd[++len] = (node){u, w[u] + dep[u]};
for (int i = 0; i < g[u].size(); i++)
{
int v = g[u][i], d = l[u][i];
if ((v == fa) || vis[v]) continue;
dep[v] = dep[u] + d;
get_dep(v, u);
}
}
inline void solve(int u)
{
// printf("solving %d\n", u);
vis[u] = true;
nd[len = 1] = (node){u, w[u]};
for (int i = 0; i < g[u].size(); i++)
{
int v = g[u][i], d = l[u][i];
if (vis[v]) continue;
dep[v] = d;
get_dep(v, u);
}
sort(nd + 1, nd + len + 1);
// printf("debug %d\n", dep[nd[2].p]);
for (int i = 2; i <= len; i++) edge[++ecnt] = (item){nd[i].p, nd[1].p, nd[i].w + nd[1].w};
for (int v : g[u])
{
if (vis[v]) continue;
get_sz(v, 0);
tot_sz = sz[v], max_sz = inf;
get_rt(v, 0);
solve(rt);
}
}
inline int get(int u) { return (fa[u] == u ? u : fa[u] = get(fa[u])); }
inline void kruskal()
{
ll ans = 0;
sort(edge + 1, edge + ecnt + 1);
for (int i = 1; i <= n; i++) fa[i] = i;
for (int i = 1; i <= ecnt; i++)
{
int fu = get(edge[i].u), fv = get(edge[i].v);
if (fu != fv) { fa[fu] = fv, ans += edge[i].w; }
}
printf("%lld\n", ans);
}
int main()
{
n = read();
for (int i = 1; i <= n; i++) w[i] = read();
for (int i = 1, u, v, d; i <= n - 1; i++)
{
u = read(), v = read(), d = read();
g[u].push_back(v), l[u].push_back(d);
g[v].push_back(u), l[v].push_back(d);
}
tot_sz = n, max_sz = inf;
get_sz(1, 0), get_rt(1, 0);
solve(rt);
kruskal();
return 0;
}
/*
4
1 3 5 1
1 2 1
2 3 2
3 4 3
4 <-> 1 : 8
3 <-> 1 : 9
2 <-> 1 : 8
2 <-> 1 : 11
1 <-> 1 : 4
2 <-> 1 : 5
3 <-> 1 : 9
4 <-> 1 : 8
1 <-> 1 : 2
3 <-> 4 : 9
4 <-> 4 : 8
4 <-> 4 : 2
*/