AT3611 Tree MST
给定一棵\(n\)个节点的树,现有有一张完全图,两点\(x,y\)之间的边长为\(w[x]+w[y]+dis(x,y)\),其中\(dis\)表示树上两点的距离。
求完全图的\(MST\)。
\(2\ \leq\ N\ \leq\ 200,000\)
Orz mrsrz
先对每个点 \(u\) 建一个虚点 \(u'\) ,边权为 \(w[u]\) ,然后我们对每个点求出 \(f_u\) 表示距离 \(u'\) 最近的虚点的距离,以及 \(g_u\) 表示距离 \(u'\) 最近的虚点的原来的点的编号,可以直接换根dp求出。
然后对原树上的边 \((u,v,w)\) ,用新边 \((g_u,g_v,f_u+f_v+w)\) 做克鲁斯卡尔就可以得到答案。
考虑这么做的正确性,假设 \(u,v\) 之间没有连新边,那么对于 \(u,v\) 之间的点 \(x_1,x_2\dots x_k\) ,一定会有 \(g_{x_t}\neq u\) 或 \(g_{x_{t+1}}\neq v\) ,这说明一定存在若干个点之间有边,并且这些点会存在两个点满足 \(dist(x,u)\le dist(u,v)\) ,\(dist(y,v)\le dist(u,v)\) ,所以 \(u,v\) 这条边是多余的。
会比同等复杂度的B算法快2倍多。
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
const int N = 2e5;
using namespace std;
struct edges
{
int to,cost;
}edge[N * 2 + 5];
struct line
{
int u,v;
long long w;
}e[N * 2 + 5];
int n,w[N + 5],fa[N + 5],nxt[N * 2 + 5],head[N + 5],edge_cnt,g[N + 5],ec,cnt;
long long dis[N + 5],ans;
void add_edge(int u,int v,int w)
{
edge[++edge_cnt] = (edges){v,w};
nxt[edge_cnt] = head[u];
head[u] = edge_cnt;
}
void dfs1(int u,int fa)
{
g[u] = u;
dis[u] = w[u];
for (int i = head[u];i;i = nxt[i])
{
int v = edge[i].to,w = edge[i].cost;
if (v == fa)
continue;
dfs1(v,u);
if (dis[u] > dis[v] + w)
{
dis[u] = dis[v] + w;
g[u] = g[v];
}
}
}
void dfs2(int u,int fa)
{
for (int i = head[u];i;i = nxt[i])
{
int v = edge[i].to,w = edge[i].cost;
if (v == fa)
continue;
if (dis[v] > dis[u] + w)
{
dis[v] = dis[u] + w;
g[v] = g[u];
}
if (g[u] != g[v])
e[++ec] = (line){g[u],g[v],dis[u] + w + dis[v]};
dfs2(v,u);
}
if (g[u] != u)
e[++ec] = (line){u,g[u],dis[u] + w[u]};
}
bool cmp(line a,line b)
{
return a.w < b.w;
}
int find(int x)
{
if (fa[x] == x)
return x;
return fa[x] = find(fa[x]);
}
int main()
{
scanf("%d",&n);
for (int i = 1;i <= n;i++)
scanf("%d",&w[i]);
int u,v,w;
for (int i = 1;i < n;i++)
{
scanf("%d%d%d",&u,&v,&w);
add_edge(u,v,w);
add_edge(v,u,w);
}
dfs1(1,0);
dfs2(1,0);
sort(e + 1,e + ec + 1,cmp);
for (int i = 1;i <= n;i++)
fa[i] = i;
for (int i = 1;i <= ec;i++)
{
int X = find(e[i].u),Y = find(e[i].v);
if (X != Y)
{
fa[X] = Y;
ans += e[i].w;
cnt++;
}
if (cnt == n - 1)
{
cout<<ans<<endl;
return 0;
}
}
return 0;
}