【AtCoder3611】Tree MST(点分治,最小生成树)
【AtCoder3611】Tree MST(点分治,最小生成树)
题面
AtCoder
洛谷
给定一棵\(n\)个节点的树,现有有一张完全图,两点\(x,y\)之间的边长为\(w[x]+w[y]+dis(x,y)\),其中\(dis\)表示树上两点的距离。
求完全图的\(MST\)。
题解
首先连边的这个式子可以直接转换成树上的两点间的路径,所以接下来只考虑\(dis(x,y)\)。
考虑\(Boruvka\)算法的执行过程,每次都会选择到达一个点集最近的一个点,然后将他们连边。
现在考虑模拟这个过程,那么在树上我们钦定一个点作为根节点,考虑过根节点的路径的连边情况。
对于每个点我们要找到离他最近的点,因此显然就是在其他子树内找到一个距离根节点最小的点然后让这个点向所有其他的点连边,这样子对于每个根节点我们都可以找到唯一的点,然后让它向其他所有点连边就好了。
显然不用以所有点为根,只需要把当前根节点丢掉,把子树再单独处理就好了。
不难发现点分治就是这么一个过程,因此对于每个分治重心找到距离根节点距离最近的点。
这样子一来点分治是\(O(nlogn)\),边数是\(O(nlogn)\),最后再跑一遍克鲁斯卡尔。
因此总的复杂度就是\(O(nlog^2n)\)。
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define MAX 200200
#define ll long long
inline int read()
{
int x=0;bool t=false;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=true,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return t?-x:x;
}
struct Line{int v,next,w;}e[MAX<<1];
int h[MAX],cnt=1;
inline void Add(int u,int v,int w){e[cnt]=(Line){v,h[u],w};h[u]=cnt++;}
struct Edge{int u,v;ll w;}E[MAX*50];
bool operator<(Edge a,Edge b){return a.w<b.w;}
int Size,mx,rt,size[MAX];bool vis[MAX];
int n,m,W[MAX];
void Getroot(int u,int ff)
{
int ret=0;size[u]=1;
for(int i=h[u];i;i=e[i].next)
{
int v=e[i].v;if(vis[v]||v==ff)continue;
Getroot(v,u);size[u]+=size[v];
ret=max(ret,size[v]);
}
ret=max(ret,Size-size[u]);
if(ret<mx)mx=ret,rt=u;
}
int P;ll Val;
void dfs(int u,int ff,ll dep)
{
if(dep+W[u]<Val)Val=dep+W[u],P=u;
for(int i=h[u];i;i=e[i].next)
if(!vis[e[i].v]&&e[i].v!=ff)
dfs(e[i].v,u,dep+e[i].w);
}
void Link(int u,int ff,ll dep)
{
E[++m]=(Edge){u,P,Val+dep+W[u]};
for(int i=h[u];i;i=e[i].next)
if(!vis[e[i].v]&&e[i].v!=ff)
Link(e[i].v,u,dep+e[i].w);
}
void Divide(int u)
{
vis[u]=true;
Val=1e18;P=0;dfs(u,0,0);Link(u,0,0);
for(int i=h[u];i;i=e[i].next)
{
int v=e[i].v;if(vis[v])continue;
Size=mx=size[v];Getroot(v,u);
Divide(rt);
}
}
int f[MAX];ll ans;
int getf(int x){return x==f[x]?x:f[x]=getf(f[x]);}
int main()
{
n=read();
for(int i=1;i<=n;++i)W[i]=read();
for(int i=1;i<n;++i)
{
int u=read(),v=read(),w=read();
Add(u,v,w);Add(v,u,w);
}
Size=mx=n;Getroot(1,0);Divide(rt);
sort(&E[1],&E[m+1]);
for(int i=1;i<=n;++i)f[i]=i;
for(int i=1;i<=m;++i)
{
int u=getf(E[i].u),v=getf(E[i].v);
if(u==v)continue;
ans+=E[i].w;f[u]=v;
}
printf("%lld\n",ans);
return 0;
}