Loading

【题解】AT3611 Tree MST

喝,长大了🎵↑↓↑↓

思路

点分治 + kruskal.

对于完全图 MST 问题有一个结论:

\(\forall G = (V, E)\),若 \(E_1 \cup ... \cup E_m = E\),设 \(E_i\) 经过 kruskal 得到的有效边集为 \(E^{\prime}_i\),则对 \(E^{\prime}_1 \cup ... \cup E^{\prime}_m\) 进行 kruskal 得到的是 \(G\) 的 MST。

又发现题目考虑的是树上路径问题,所以想到点分治。

因为点分治可以考虑到树上的每一条路径,它们取并显然可以覆盖所有边,所以可以直接在点分的时候求当前子树的 MST。

因为点分的时候钦定子树重心为根,所以树上路径可以换一种形式写。

\(dep_u\) 表示从当前分治的根结点到点 \(u\) 的距离,那么在 \((u, v)\) 之间连边的代价是 \(w_u + w_v + dep_u + dep_v\)

可以把贡献拆开,\(w_u + dep_u\) 赋给 \(u\)\(v\) 同理。

那么只需要在子树中找出 \(w_u + dep_u\) 最小的点 \(u\),然后子树内的其他所有点都向点 \(u\) 连边即可。

那么只需要进行一遍点分就可以找出所有有用的点,直接 kruskal 一次就行。

点分的复杂度是 \(O(n \log n)\),所以至多有 \(O(n \log n)\) 条边,总复杂度 \(O(n \log^2 n)\)

代码

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

typedef long long ll;

const int maxn = 2e5 + 5;
const int maxm = 5e6 + 5;
const int inf = 0x3f3f3f3f;

struct item
{
    int u, v;
    ll w;

    bool operator < (const item& rhs) const { return (w < rhs.w); }
} edge[maxm];

struct node
{
    int p;
    ll w;

    bool operator < (const node& rhs) const { return (w < rhs.w); }
} nd[maxn];

int n, ecnt, len;
int rt, max_sz, tot_sz;
int w[maxn], fa[maxn], sz[maxn];
bool vis[maxn];
ll dep[maxn];
vector<int> g[maxn], l[maxn];

inline int read()
{
    int res = 0;
    char ch = getchar();
    while ((ch < '0') || (ch > '9')) ch = getchar();
    while ((ch >= '0') && (ch <= '9'))
    {
        res = res * 10 + ch - '0';
        ch = getchar();
    }
    return res;
}

inline void get_sz(int u, int fa)
{
    sz[u] = 1;
    for (int v : g[u])
    {
        if ((v == fa) || vis[v]) continue;
        get_sz(v, u);
        sz[u] += sz[v];
    }
}

inline void get_rt(int u, int fa)
{
    int res = 0;
    for (int v : g[u])
    {
        if ((v == fa) || vis[v]) continue;
        get_rt(v, u);
        res = max(res, sz[v]);
    }
    res = max(res, tot_sz - sz[u]);
    if (res < max_sz) rt = u, max_sz = res;
}

inline void get_dep(int u, int fa)
{
    nd[++len] = (node){u, w[u] + dep[u]};
    for (int i = 0; i < g[u].size(); i++)
    {
        int v = g[u][i], d = l[u][i];
        if ((v == fa) || vis[v]) continue;
        dep[v] = dep[u] + d;
        get_dep(v, u);
    }
}

inline void solve(int u)
{
    // printf("solving %d\n", u);
    vis[u] = true;
    nd[len = 1] = (node){u, w[u]};
    for (int i = 0; i < g[u].size(); i++)
    {
        int v = g[u][i], d = l[u][i];
        if (vis[v]) continue;
        dep[v] = d;
        get_dep(v, u);
    }
    sort(nd + 1, nd + len + 1);
    // printf("debug %d\n", dep[nd[2].p]);
    for (int i = 2; i <= len; i++) edge[++ecnt] = (item){nd[i].p, nd[1].p, nd[i].w + nd[1].w};
    for (int v : g[u])
    {
        if (vis[v]) continue;
        get_sz(v, 0);
        tot_sz = sz[v], max_sz = inf;
        get_rt(v, 0);
        solve(rt);
    }
}

inline int get(int u) { return (fa[u] == u ? u : fa[u] = get(fa[u])); }

inline void kruskal()
{
    ll ans = 0;
    sort(edge + 1, edge + ecnt + 1);
    for (int i = 1; i <= n; i++) fa[i] = i;
    for (int i = 1; i <= ecnt; i++)
    {
        int fu = get(edge[i].u), fv = get(edge[i].v);
        if (fu != fv) { fa[fu] = fv, ans += edge[i].w; }
    }
    printf("%lld\n", ans);
}

int main()
{
    n = read();
    for (int i = 1; i <= n; i++) w[i] = read();
    for (int i = 1, u, v, d; i <= n - 1; i++)
    {
        u = read(), v = read(), d = read();
        g[u].push_back(v), l[u].push_back(d);
        g[v].push_back(u), l[v].push_back(d);
    }
    tot_sz = n, max_sz = inf;
    get_sz(1, 0), get_rt(1, 0);
    solve(rt);
    kruskal();
    return 0;
}

/*
4
1 3 5 1
1 2 1
2 3 2
3 4 3

4 <-> 1 : 8
3 <-> 1 : 9
2 <-> 1 : 8
2 <-> 1 : 11

1 <-> 1 : 4
2 <-> 1 : 5
3 <-> 1 : 9
4 <-> 1 : 8
1 <-> 1 : 2
3 <-> 4 : 9
4 <-> 4 : 8
4 <-> 4 : 2
*/
posted @ 2023-01-11 20:05  kymru  阅读(24)  评论(0编辑  收藏  举报