【AT3611】Tree MST
这个题的输入首先就是一棵树,我们考虑一下点分
我们对于每一个分治重心考虑一下跨过这个分治重心的连边情况
就是把当前分治区域内所有的点向距离分治重心最近的点连边
考虑一下这个算法的正确性,如果我们已经对一个联通块内部形成了一个\(mst\),我们需要把这个联通块和另外一个联通块合并
如果这个新的联通块出现会使得原来联通块的\(mst\)改变,那么新出现的边也只会是原来联通块的点和新联通块到这个点距离最近的点之间的边,而这些最近的点又都是一个,所以我们就可以大大简化连边数量了
考虑这样一个结论:
对于图 \(G=(V,E)\) ,对于两个不相交的点集 \(V_1,V_2\) ,如果
一条边没有出现在 \(V_1\) 的导出子图的 MST 中,那么也不会出现
在 \(G\) 的 MST 中。
扩展到多个点集也是成立的。正确性考虑 Kruskal 的过程。
那么考虑点分治,假设对于接下来的各个分治结构都已经求得了 MST,
那么需要做的是合并当前的多个 MST。
注意边权 \(w_x+w_y+dis_{x,y}\) 可以写成 \((w_x+dis_x) +(w_y+dis_y)\),
其中 \(dis_x\) 是到当前分治中心的距离,那么最优连边的方案找一个点 \(x\) 使
得 \(w_x+dis_x\) ,之后其余点都向这些点连边。
所以这个点分的过程就相当于合并 MST 的过程。
连边数量是 \(n\log n\) 级别,再进行一次 kruskal,复杂度为 \(O(n\log^2n)\)
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
inline int read() {
char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=2e5+5;
struct E{int v,nxt,w;}e[maxn<<1];
struct Edge{int a,b;LL c;}E[maxn*55];
int sum[maxn],vis[maxn],head[maxn],mx[maxn],a[maxn],fa[maxn],sz[maxn];
int n,num,m,dx,S,rt;LL dw,ans,pre[maxn];
inline void add(int x,int y,int z) {
e[++num].v=y;e[num].nxt=head[x];head[x]=num;e[num].w=z;
}
void getroot(int x,int fa) {
sum[x]=1,mx[x]=0;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||e[i].v==fa) continue;
getroot(e[i].v,x);sum[x]+=sum[e[i].v];
mx[x]=max(mx[x],sum[e[i].v]);
}
mx[x]=max(mx[x],S-sum[x]);
if(mx[x]<mx[rt]) rt=x;
}
void getdis(int x,int fa) {
E[++m]=(Edge){dx,x,pre[x]+a[x]+dw};
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||e[i].v==fa) continue;
getdis(e[i].v,x);
}
}
void chk(int x,int fa) {
if(pre[x]+a[x]<dw) dw=pre[x]+a[x],dx=x;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||e[i].v==fa) continue;
pre[e[i].v]=pre[x]+e[i].w;chk(e[i].v,x);
}
}
void dfs(int x) {
dx=x,dw=a[x];vis[x]=1;pre[x]=0,chk(x,0),getdis(x,0);
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]) continue;
S=sum[e[i].v],rt=0,getroot(e[i].v,0),dfs(rt);
}
}
inline int cmp(Edge A,Edge B) {return A.c<B.c;}
inline int find(int x) {return x==fa[x]?x:fa[x]=find(fa[x]);}
inline int merge(int x,int y) {
int xx=find(x),yy=find(y);
if(xx==yy) return 0;
if(sz[xx]<sz[yy]) fa[xx]=yy,sz[yy]+=sz[xx];
else fa[yy]=xx,sz[xx]+=sz[yy];
return 1;
}
int main() {
n=read();
for(re int i=1;i<=n;i++) a[i]=read();
for(re int x,y,z,i=1;i<n;i++)
x=read(),y=read(),z=read(),add(x,y,z),add(y,x,z);
mx[0]=n+1,S=n,rt=0,getroot(1,0),dfs(rt);
std::sort(E+1,E+m+1,cmp);
for(re int i=1;i<=n;i++) sz[i]=1,fa[i]=i;
for(re int i=1;i<=m;i++) if(merge(E[i].a,E[i].b)) ans+=E[i].c;
std::cout<<ans;
return 0;
}