河童重工 题解

题意

给定两棵 \(n\) 个点的带边权树。
求包含 \(n\) 个点的,\(i\)\(j\) 之间边权为 \(d_1(i, j) + d_2(i, j)\) 的完全图的最小生成树大小。
其中 \(d_x(i, j)\) 表示 \(i\)\(j\) 两点在第 \(x\) 棵树上的距离

\(1 ≤ n ≤ 10^5, 0 ≤ w ≤ 500\)

题解

完全图求 \(\text{MST}\) ,首先想到 \(\text{Boruka}\) ,但边权跟两棵树有关,不会了?

对一个图求 \(\text{MST}\) 有一个性质:把 \(G\) 划分为几个子图分别求 \(\text{MST}\) ,把留下的边合起来再求 \(\text{MST}\),与原图 \(\text{MST}\) 相同。

证明即根据 \(Kruskal\) 算法,划分为子图后 不会被丢掉的边一定不会被丢掉 。

那么怎么划分边集使其好计算呢?

诶,完全图的边集就是树上的所有路径构成的集合对吧。所有路径,点分治!

\(Tree_2\) 点分治,考虑计算当前分治块内、跨过分治中心的路径在完全图上的边构成的 \(\text{MST}\)

把与分治中心的距离当做点权,因为同一子树内的路径在之后会计算,此时加进去不会对答案有影响,所以只用考虑 \(Tree_1\) 的贡献。

对于 \(Tree_1\) ,问题变为点有点权 \(w\) ,求完全图的 \(\text{MST}\)\(e(u,v)=w_u+w_v+dis(u,v)\) 。如果你做过 \(Tree MST\) 那么此时已经做完了。

我们运用 \(\text{Boruka}\) ,每轮可以用 \(DP\) + 换根 \(DP\) 记录每个点最近的点和最近且颜色不同的点 和 其当前联通块颜色,然后把每个点贡献给其联通块即可。

为了保证复杂度,需要用到虚树,本来不存在的点的 \(w\) 赋为 \(\infty\) 即可,需要注意的是对虚树跑 \(\text{Boruka}\) 要在建虚树后更新树大小。

Code

//from 2022.3.19 18:44
#include<bits/stdc++.h>
#define ri register int
#define ll long long
#define pii pair<int,ll>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
using namespace std;
const int maxn = 1e5 + 10;
const ll inf = 0x3f3f3f3f3f3f3f3f,mx = 1e10;
template<class T> inline void rd(T &x){
	x = 0; char ch = getchar();
	while(!isdigit(ch)) ch = getchar();
	while(isdigit(ch)) x = x * 10 + ch - 48,ch = getchar();
}
int cnt,tmp[maxn],dsu[maxn];
ll w[maxn];
inline int find(int x) {return (dsu[x] == x) ? dsu[x] : (dsu[x] = find(dsu[x]));}
inline void merge(int u,int v) {u = find(u),v = find(v),dsu[u] = v;}
int dfn[maxn],n;
inline bool cmp(const int &x,const int &y) {return dfn[x] < dfn[y];}
struct edge{
	int u,v; ll w;
	friend bool operator <(edge x,edge y){return x.w < y.w;}
};
vector<edge> e;
struct Tree1{
	struct node{pii mn,sec;}f[maxn];
	int top[maxn],fa[maxn],sz[maxn],son[maxn],dep[maxn],c[maxn],dfnum,fr[maxn],st[maxn],stop;
	pii best[maxn];
	ll dis[maxn];
	vector<pii > to[maxn];
	vector<int> cur;
	void dfs1(int u,int pa){
		sz[u] = 1;
		for(auto v : to[u])
			if(v.fi ^ pa){
				dep[v.fi] = dep[u] + 1,fa[v.fi] = u,dis[v.fi] = dis[u] + v.se,dfs1(v.fi,u);
				sz[u] += sz[v.fi]; if(sz[v.fi] > sz[son[u]]) son[u] = v.fi;
			}
	}
	void dfs2(int u,int toop){
		top[u] = toop,dfn[u] = ++dfnum;
		if(son[u]) dfs2(son[u],toop);
		for(auto v : to[u])
			if((v.fi ^ fa[u]) && (v.fi ^ son[u])) dfs2(v.fi,v.fi);
	}
	inline int lca(int u,int v){
		while(top[u] ^ top[v]){
			if(dep[top[u]] < dep[top[v]]) swap(u,v);
			u = fa[top[u]];
		}
		if(dep[u] > dep[v]) swap(u,v); return u;
	}
	inline ll Dis(int u,int v) {return dis[u] + dis[v] - min(dis[u],dis[v])*2;}
	inline void add(int u,int v,int wi) {to[u].pb(mp(v,wi)),to[v].pb(mp(u,wi));}
	inline void vt(){//to bulid virtual trees
		cur.resize(0); tmp[++cnt] = 1;
		sort(tmp + 1,tmp + cnt + 1,cmp);
		if(tmp[2] != 1) w[1] = mx;
		else unique(tmp + 1,tmp + cnt + 1),cnt--;
		for(ri i = 1;i <= cnt;++i) cur.pb(tmp[i]);
		st[stop = 1] = 1;
		for(ri i = 2;i <= cnt;++i){
			int l = lca(st[stop],tmp[i]);
			if(dfn[l] == dfn[st[stop]]) {st[++stop] = tmp[i]; continue;}
			while(dfn[l] < dfn[st[stop-1]])
				add(st[stop-1],st[stop],Dis(st[stop-1],st[stop])),stop--;
			if(l == st[stop - 1]) add(st[stop],st[stop-1],Dis(st[stop],st[stop-1])),stop--;
			else add(l,st[stop],Dis(l,st[stop])),st[stop] = l,w[l] = mx,cur.pb(l);
			st[++stop] = tmp[i];
		}
		while(stop > 1) add(st[stop],st[stop-1],Dis(st[stop],st[stop-1])),stop--;
		for(auto it : cur) dsu[it] = it;
	}
	inline void update(node &now,node tr){
		if(tr.mn.se < now.mn.se){
			if((c[now.mn.fi] != c[tr.mn.fi]) && (now.mn.se < tr.sec.se)) tr.sec = now.mn;
			now.mn = tr.mn;
			if(c[now.mn.fi] == c[now.sec.fi]) now.sec.se = inf;
			if((tr.sec.se < now.sec.se) && (c[tr.sec.fi] ^ c[now.mn.fi])) now.sec = tr.sec;
		}
		else{
			if((tr.mn.se < now.sec.se) && (c[tr.mn.fi] ^ c[now.mn.fi])) now.sec = tr.mn;
			if((tr.sec.se < now.sec.se) && (c[tr.sec.fi] ^ c[now.mn.fi])) now.sec = tr.sec;
		}
	}
	void dfs(int u,int pa){
		for(auto v : to[u]) if(v.fi ^ pa) dfs(v.fi,u);
		f[u].mn.se = w[u] * 2,f[u].mn.fi = u;
		for(auto v : to[u])
			if(v.fi ^ pa){
				node tr = f[v.fi];
				tr.mn.se += v.se - w[v.fi] + w[u],tr.sec.se += v.se - w[v.fi] + w[u];
				update(f[u],tr);
			}
	}
	void rdfs(int u,int pa){
		for(auto v : to[u])
			if(v.fi ^ pa){
				int vv = v.fi,ww = v.se;
				node tr = f[u];
				tr.mn.se += ww - w[u] + w[vv],tr.sec.se += ww - w[u] + w[vv];
				update(f[vv],tr);
				rdfs(vv,u);
			}
	}
	void boruka(){
		vt(); int num = 0;
		cnt = cur.size();
		while(num < (cnt - 1)){
			for(auto it : cur)
				f[it].sec.se = inf,best[it] = mp(0,inf),c[it] = find(it);
			dfs(1,1),rdfs(1,1);
			for(auto it : cur){
				if(f[it].mn.se == inf) continue;
				pii tr = f[it].mn; if(c[it] == c[tr.fi]) tr = f[it].sec;
				if(tr.se < best[c[it]].se) best[c[it]] = tr,fr[c[it]] = it;
			}
			for(auto it : cur)
				if((find(it) == it) && (find(best[it].fi) != find(it))){
					if(best[it].se == inf) continue;
					e.pb((edge){fr[it],best[it].fi,best[it].se}),num++;
					merge(it,best[it].fi);
				}
		}
		for(auto it : cur) to[it].resize(0); cur.resize(0);
	}
}T1;//for vir_tree,boruka
struct Tree2{
	int sz[maxn],siz,rt,maxs[maxn];
	bool vis[maxn]; ll dis[maxn];
	vector<pii > to[maxn];
	inline void add(int u,int v,int w) {to[u].pb(mp(v,w)),to[v].pb(mp(u,w));}
	void dfs(int u,int f){
		tmp[++cnt] = u;
		for(auto v : to[u]) if(!vis[v.fi] && (v.fi ^ f)) w[v.fi] = w[u] + v.se,dfs(v.fi,u);
	}
	void getsz(int u,int f){
		sz[u] = 1;
		for(auto v : to[u]) if(!vis[v.fi] && (v.fi ^ f)) getsz(v.fi,u),sz[u] += sz[v.fi];
	}
	void findrt(int u,int f){
		maxs[u] = 0;
		for(auto v : to[u])
			if(!vis[v.fi] && (v.fi ^ f)) findrt(v.fi,u),maxs[u] = max(maxs[u],sz[v.fi]);
		maxs[u] = max(maxs[u],siz-sz[u]);
		if(maxs[u] < maxs[rt]) rt = u;
	}
	void dfz(int u){
		vis[u] = 1,cnt = 0,w[u] = 0,getsz(u,u),dfs(u,u);
		T1.boruka();
		for(auto v : to[u])
			if(!vis[v.fi]) rt = 0,siz = sz[v.fi],findrt(v.fi,v.fi),dfz(rt);
	}
	inline void work(){
		maxs[0] = n + 1,siz = n,rt = 0,getsz(1,1),findrt(1,1); dfz(rt);
	};
}T2;//for dfz
int main(){
	rd(n);
	for(ri i = 1,u,v,w;i < n;++i) rd(u),rd(v),rd(w),T1.add(u,v,w);
	for(ri i = 1,u,v,w;i < n;++i) rd(u),rd(v),rd(w),T2.add(u,v,w);
	T1.dfs1(1,1),T1.dfs2(1,1); for(ri i = 1;i <= n;++i) T1.to[i].resize(0);
	T2.work(); ll ans = 0;
	for(ri i = 1;i <= n;++i) dsu[i] = i;
	sort(e.begin(),e.end());
	for(auto ei : e) if(find(ei.u) ^ find(ei.v)) merge(ei.u,ei.v),ans += ei.w;
	printf("%lld\n",ans);
	return 0;
}
posted @ 2022-03-20 22:53  Lumos壹玖贰壹  阅读(42)  评论(0编辑  收藏  举报