题解-AtCoder Code-Festival2017 Final-J Tree MST

Problem

\(\mathrm{Code~Festival~2017~Final~J}\)

题意概要:一棵 \(n\) 个节点有点权边权的树。构建一张完全图,对于任意一对点 \((x,y)\),连一条长度为 \(w[x] + w[y]+ dis(x, y)\) 的边。求这张图的最小生成树。

\(n\leq 2\times 10^5\)

Solution

在操场上晒太阳时想到的做法,求 \(\mathrm{MST}\) 可以使用另一种贪心算法:每次找到每个点连出去的最短的边,并将其合并,一次是 \(O(n)\),由于每次点数至少减半,所以总共不超过 \(\log n\)次,总复杂度 \(O(n\log n)\)

使用这种贪心算法后,只需每次找到离每个点最近的点。

可以使用点分治,设已经合并的点为同一连通块。考虑分治中心为 \(x\),只考虑过分治中心的路径,求出 \(dep+w\) 最小的点,对于每棵子树内的点,只有非子树内的点可能做贡献,而对于每个点,只有非同连通块的点可做贡献。所以需要维护四个值,这样较麻烦,或者是只维护两个值加上处理前后缀(具体可以看代码)。复杂度 \(O(n\log^2n)\)

然后搜了一波题解,发现一群人在同一天使用了同一个做法(可能是他们在讲课后统一发的题解):同样考虑上述贪心,只是点分治时不用考虑是否在同一子树内,而是都连过去,这样保证结果不会低于答案,稍加分析发现能得到最优解。

又看了看官方正解,发现不需要点分治,直接换根Dp即可……可能是老年选手已经开始老年痴呆了

Code

哦,这样常数有点大,我的代码跑极限数据 \(\mathrm {5.01s}\),会 T 三个点,预处理点分树即可

//Code Festival 2017 Final-J
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

template <typename _tp> inline void cmax(_tp&A,const _tp&B){if(A < B) A = B;}
template <typename _tp> inline void cmin(_tp&A,const _tp&B){if(A > B) A = B;}

template <typename _tp> inline void read(_tp&x){
	char c11=getchar(),ob=0;x=0;
	while(c11!='-'&&!isdigit(c11))c11=getchar();if(c11=='-')ob=1,c11=getchar();
	while(isdigit(c11))x=x*10+c11-'0',c11=getchar();if(ob)x=-x;
}

const ll Inf = 2e18;
const int N = 201000;
struct Edge{int v,nxt;ll w;}a[N+N+N];
int head[N],Head[N],vs[N],w[N],id[N];
int n,_;

inline void ad(){
	static int x,y,z; read(x), read(y), read(z);
	a[++_].v = y, a[_].w = z, a[_].nxt = head[x], head[x] = _;
	a[++_].v = x, a[_].w = z, a[_].nxt = head[y], head[y] = _;
}

namespace dsu{
	int dad[N];
	int find(int x){return dad[x]? dad[x] = find(dad[x]): x;}
	bool check(int x,int y){return find(x) == find(y);}
	bool merge(int x,int y){
		static int p1,p2;
		if((p1 = find(x)) == (p2 = find(y))) return false;
		dad[p1] = p2; return true;
	}
}

namespace TD{
	int sz[N], rt, Mi, nn;
	void get_rt(int x,int las){
		sz[x] = 1;int mx = 0;
		for(int i=head[x];i;i=a[i].nxt)
			if(a[i].v!=las and !vs[a[i].v]){
				get_rt(a[i].v,x);
				sz[x] += sz[a[i].v];
				cmax(mx, sz[a[i].v]);
			}
		cmax(mx, nn - sz[x]);
		if(mx < Mi) Mi = mx, rt = x;
	}
	
	void Get_rt(int x,int xn){rt = 0, nn = xn, Mi = 2e9; get_rt(x,0);}
	
	void build(int x,int las){
		vs[x] = 1;
		a[++_].v = x, a[_].nxt = Head[las], Head[las] = _;
		get_rt(x,0);
		for(int i=head[x];i;i=a[i].nxt)
			if(!vs[a[i].v]){
				Get_rt(a[i].v,sz[a[i].v]);
				build(rt,x);
			}
	}
}

struct node{
	ll v;int id;
	inline node(){}
	inline node(const ll&V,const int&Id):v(V),id(Id){}
}tr[N], p[N], Mx, Mi, Fir[N], Sec[N];

node pre_fir[N], pre_sec[N];
node suf_fir[N], suf_sec[N];

inline void upd(node&A, node&B, node nw){
	if(nw.v < A.v) {
		if(nw.id == A.id) {A = nw; return ;}
		B = A, A = nw; return ;
	}
	if(nw.v < B.v)
		if(nw.id != A.id) B = nw;
}

void get_val(int x,int las,ll dep){
	upd(Mi, Mx, node(dep+w[x],id[x]));
	for(int i=head[x];i;i=a[i].nxt)
		if(a[i].v!=las and !vs[a[i].v])
			get_val(a[i].v,x,dep+a[i].w);
}

void cover(int x,int las,ll dep,node A,node B){
	if(id[x] != A.id and dep + A.v < p[x].v)
		p[x].v = dep + A.v, p[x].id = A.id;
	if(id[x] != B.id and dep + B.v < p[x].v)
		p[x].v = dep + B.v, p[x].id = B.id;
	for(int i=head[x];i;i=a[i].nxt)
		if(a[i].v!=las and !vs[a[i].v])
			cover(a[i].v,x,dep+a[i].w,A,B);
}

int to[N], to_w[N];

void work(int x){
	vs[x] = 1;
	int top = 0;
	for(int i=head[x];i;i=a[i].nxt)
		if(!vs[a[i].v]){
			Mi = node(w[x],id[x]), Mx = node(Inf,0);
			get_val(a[i].v,x,a[i].w);
			++top, to[top] = a[i].v, to_w[top] = a[i].w;
			Fir[top] = Mi, Sec[top] = Mx;
		}
	
	pre_fir[1] = Fir[1];
	pre_sec[1] = Sec[1];
	for(int i=2;i<=top;++i){
		pre_fir[i] = pre_fir[i-1];
		pre_sec[i] = pre_sec[i-1];
		upd(pre_fir[i], pre_sec[i], Fir[i]);
		upd(pre_fir[i], pre_sec[i], Sec[i]);
	}
	suf_fir[top] = Fir[top];
	suf_sec[top] = Sec[top];
	for(int i=top-1;i>=1;--i){
		suf_fir[i] = suf_fir[i+1];
		suf_sec[i] = suf_sec[i+1];
		upd(suf_fir[i], suf_sec[i], Fir[i]);
		upd(suf_fir[i], suf_sec[i], Sec[i]);
	}
	
	node A,B;
	for(int i=1;i<=top;++i){
		A = node(w[x],id[x]), B = node(Inf,0);
		if(i!=1) upd(A,B,pre_fir[i-1]);
		if(i!=1) upd(A,B,pre_sec[i-1]);
		if(i!=top) upd(A,B,suf_fir[i+1]);
		if(i!=top) upd(A,B,suf_sec[i+1]);
		cover(to[i],x,to_w[i],A,B);
	}
	
	if(top){
		A = pre_fir[top], B = pre_sec[top];
		if(id[x] != A.id and A.v < p[x].v)
			p[x].v = A.v, p[x].id = A.id;
		if(id[x] != B.id and B.v < p[x].v)
			p[x].v = B.v, p[x].id = B.id;
	}
	
	for(int i=Head[x];i;i=a[i].nxt)
		work(a[i].v);
}

int main(){
	read(n);
	for(int i=1;i<=n;++i)read(w[i]);
	for(int i=1;i<n;++i)ad();
	
	TD::build(1,0);
	
	int Tot = n; ll Ans = 0ll;
	while(Tot > 1){
		for(int i=1;i<=n;++i) id[i] = dsu::find(i), p[i].v = Inf, vs[i] = 0;
		work(a[Head[0]].v);
		for(int i=1;i<=n;++i) tr[i].v = Inf;
		for(int i=1,t;i<=n;++i){
			t = dsu::find(i);
			if(tr[t].v > w[i] + p[i].v)
				tr[t].v = p[i].v + w[i], tr[t].id = p[i].id;
		}
		for(int i=1;i<=n;++i)
			if(dsu::find(i) == i){
				if(dsu::check(i, tr[i].id)) continue;
				Ans += tr[i].v, dsu::merge(i, tr[i].id);
				--Tot;
			}
	}
	printf("%lld\n",Ans);
	return 0;
}
posted @ 2019-03-03 21:20  oier_hzy  阅读(179)  评论(0编辑  收藏  举报