AT3611 Tree MST

给定一棵\(n\)个节点的树,现有有一张完全图,两点\(x,y\)之间的边长为\(w[x]+w[y]+dis(x,y)\),其中\(dis\)表示树上两点的距离。

求完全图的\(MST\)

\(2\ \leq\ N\ \leq\ 200,000\)


Orz mrsrz

先对每个点 \(u\) 建一个虚点 \(u'\) ,边权为 \(w[u]\) ,然后我们对每个点求出 \(f_u\) 表示距离 \(u'\) 最近的虚点的距离,以及 \(g_u\) 表示距离 \(u'\) 最近的虚点的原来的点的编号,可以直接换根dp求出。

然后对原树上的边 \((u,v,w)\) ,用新边 \((g_u,g_v,f_u+f_v+w)\) 做克鲁斯卡尔就可以得到答案。

考虑这么做的正确性,假设 \(u,v\) 之间没有连新边,那么对于 \(u,v\) 之间的点 \(x_1,x_2\dots x_k\) ,一定会有 \(g_{x_t}\neq u\)\(g_{x_{t+1}}\neq v\) ,这说明一定存在若干个点之间有边,并且这些点会存在两个点满足 \(dist(x,u)\le dist(u,v)\)\(dist(y,v)\le dist(u,v)\) ,所以 \(u,v\) 这条边是多余的。

会比同等复杂度的B算法快2倍多。

Code

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
const int N = 2e5;
using namespace std;
struct edges
{
    int to,cost;
}edge[N * 2 + 5];
struct line
{
    int u,v;
    long long w;
}e[N * 2 + 5];
int n,w[N + 5],fa[N + 5],nxt[N * 2 + 5],head[N + 5],edge_cnt,g[N + 5],ec,cnt;
long long dis[N + 5],ans;
void add_edge(int u,int v,int w)
{
    edge[++edge_cnt] = (edges){v,w};
    nxt[edge_cnt] = head[u];
    head[u] = edge_cnt;
}
void dfs1(int u,int fa)
{
    g[u] = u;
    dis[u] = w[u];
    for (int i = head[u];i;i = nxt[i])
    {
        int v = edge[i].to,w = edge[i].cost;
        if (v == fa)    
            continue;
        dfs1(v,u);
        if (dis[u] > dis[v] + w)
        {
            dis[u] = dis[v] + w;
            g[u] = g[v];
        }
    }
}
void dfs2(int u,int fa)
{
    for (int i = head[u];i;i = nxt[i])
    {
        int v = edge[i].to,w = edge[i].cost;
        if (v == fa)
            continue;
        if (dis[v] > dis[u] + w)
        {
            dis[v] = dis[u] + w;
            g[v] = g[u];
        }
        if (g[u] != g[v])
            e[++ec] = (line){g[u],g[v],dis[u] + w + dis[v]};
        dfs2(v,u);
    }
    if (g[u] != u)
        e[++ec] = (line){u,g[u],dis[u] + w[u]};
}
bool cmp(line a,line b)
{
    return a.w < b.w;
}
int find(int x)
{
    if (fa[x] == x)
        return x;
    return fa[x] = find(fa[x]);
}
int main()
{
    scanf("%d",&n);
    for (int i = 1;i <= n;i++)
        scanf("%d",&w[i]);
    int u,v,w;
    for (int i = 1;i < n;i++)
    {
        scanf("%d%d%d",&u,&v,&w);
        add_edge(u,v,w);
        add_edge(v,u,w);
    }
    dfs1(1,0);
    dfs2(1,0);
    sort(e + 1,e + ec + 1,cmp);
    for (int i = 1;i <= n;i++)
        fa[i] = i;
    for (int i = 1;i <= ec;i++)
    {
        int X = find(e[i].u),Y = find(e[i].v);
        if (X != Y)
        {
            fa[X] = Y;
            ans += e[i].w;
            cnt++;
        }
        if (cnt == n - 1)
        {
            cout<<ans<<endl;
            return 0;
        }
    }
    return 0;
}
posted @ 2021-03-17 10:30  eee_hoho  阅读(46)  评论(0编辑  收藏  举报