[HDU3710] Battle Over Cities [树链剖分+线段树+并查集+kruskal+思维]

题面

一句话题意:

给定一张 N 个点, M 条边的无向连通图, 每条边上有边权 w .

求删去任意一个点后的最小生成树的边权之和.

思路

首先肯定要$kruskal$一下

考虑$MST$里面去掉一个点,得到一堆联通块,我们要做的就是用原图中剩下的边把这些联通块穿起来

考虑这个点$u$在$MST$上的位置,可以知道有两种边:一种是从$u$的任意一个儿子的子树连到$u$的子树外面的,一种是在$u$的两个儿子的子树之间连接的

第一种情况:

考虑边$(u,v)$,没有进入$MST$中,那么若它是某个节点$x$从子树内连到子树外的一条边,x会在哪里呢?

显然,对于$u->lca->v$的这个链上,$x$可以是除了$u,v,lca$这三个点以外的任何一个

第二种情况:

还是考虑$(u,v)$和$x$,可以发现,此时$x$一定是$u$和$v$的$lca$

对于去掉一个点的问题而言,我们只需要保存去掉了它以后的每个联通块向外连出去的最小的一条边就可以了【这个很显然,自己想】

那么我们可以通过树链剖分+线段树,以及用一次$dfs$来询问、$dfs$过程中合并子树信息的方式,来维护上面两种情况的边,得到每个点删掉以后得到的联通块出来的最小边,再对它们跑$kruskal$

注意这样做的复杂度是对的,因为每个$MST$边都会对两个点的询问做出1的贡献,询问完$MST$中所有的点的过程中做$kruskal$的总点数是$2\ast n$

总复杂度应该是$O((m-n)\log^2n+n\log n)$的样子,还是可以接受的

注意这道题的细节非常多,写的时候一定要静态查错!

Code

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cassert>
#include<vector>
#define ll long long
using namespace std;
inline int read(){
	int re=0,flag=1;char ch=getchar();
	while(!isdigit(ch)){
		if(ch=='-') flag=-1;
		ch=getchar();
	}
	while(isdigit(ch)) re=(re<<1)+(re<<3)+ch-'0',ch=getchar();
	return re*flag;
}
int n,m,base;
int dep[100010],siz[100010],son[100010],fa[100010],dfn[100010],back[1000010],clk,top[100010],re[100010],st[100010][20];
namespace ufs{
	int f[100010];
	void init(int s){for(register int i=1;i<=s;i++) f[i]=-1;}
	inline int find(int x){return ((f[x]<0)?x:(f[x]=find(f[x])));}
	inline int join(int x,int y){
		x=find(x);y=find(y);
		if(x==y) return 0;
		f[y]+=f[x];f[x]=y;return -f[y];
	}
}
namespace seg{//线段树
	int seg[400010];
	void build(int l,int r,int num){
		seg[num]=1e9;
		if(l==r) return;
		int mid=(l+r)>>1;
		build(l,mid,num<<1);build(mid+1,r,num<<1|1);
	}
	void change(int l,int r,int ql,int qr,int num,int val){
		if(l>=ql&&r<=qr){seg[num]=min(seg[num],val);return;}
		int mid=(l+r)>>1;
		if(mid>=ql) change(l,mid,ql,qr,num<<1,val);
		if(mid<qr) change(mid+1,r,ql,qr,num<<1|1,val);
	}
	int query(int l,int r,int pos,int num){
		if(l==r) return seg[num];
		int mid=(l+r)>>1;
		if(mid>=pos) return min(query(l,mid,pos,num<<1),seg[num]);
		else return min(query(mid+1,r,pos,num<<1|1),seg[num]);
	}
}
namespace t{//这是MST树,上面有倍增、树剖等操作
	int first[100010],cnte=-1;
	void init(){
		memset(first,-1,sizeof(first));
		cnte=-1;clk=0;
		memset(re,0,sizeof(re));memset(dfn,0,sizeof(dfn));
		memset(back,0,sizeof(back));memset(top,0,sizeof(top));
		memset(st,0,sizeof(st));
	}
	struct edge{
		int to,next,w;
	}a[200010];
	inline void add(int u,int v,int w){
//		cout<<"tree addedge "<<u<<' '<<v<<' '<<w<<'\n';
		a[++cnte]=(edge){v,first[u],w};first[u]=cnte;
		a[++cnte]=(edge){u,first[v],w};first[v]=cnte;
	}
	void dfs1(int u,int f){
		int i,v;
		dep[u]=dep[f]+1;fa[u]=f;st[u][0]=f;
		siz[u]=1;son[u]=0;
		for(i=first[u];~i;i=a[i].next){
			v=a[i].to;if(v==f) continue;
			dfs1(v,u);
			siz[u]+=siz[v];
			if(siz[son[u]]<siz[v]) son[u]=v;
		}
	}
	void dfs2(int u,int t){
		int i,v;
		top[u]=t;
		dfn[u]=++clk;back[clk]=u;
		if(son[u]) dfs2(son[u],t);
		for(i=first[u];~i;i=a[i].next){
			v=a[i].to;if(v==fa[u]||v==son[u]) continue;
			dfs2(v,v);
		}
	}
	void ST(){
		for(int j=1;j<=18;j++){
			for(int i=1;i<=n;i++) st[i][j]=st[st[i][j-1]][j-1];
		}
	}
	int getlca(int l,int r){
		if(l==r) return l;
		while(top[l]!=top[r]){
			if(dep[top[l]]>=dep[top[r]]) swap(l,r);
			r=fa[top[r]];
		}
		if(dep[l]>dep[r]) swap(l,r);
		return l;
	}
	void solve(int l,int r,int val){
		int lca=getlca(l,r),tmp,i;
		if(dep[l]>dep[lca]+1){
			tmp=l;
			for(i=18;i>=0;i--) if(dep[st[tmp][i]]>dep[lca]+1) tmp=st[tmp][i];
			while(top[l]!=top[tmp]){
				seg::change(1,n,dfn[top[l]],dfn[l],1,val);
				l=fa[top[l]];
			}
			seg::change(1,n,dfn[tmp],dfn[l],1,val);
		}
		if(dep[r]>dep[lca]+1){
			tmp=r;
			for(i=18;i>=0;i--) if(dep[st[tmp][i]]>dep[lca]+1) tmp=st[tmp][i];
			while(top[r]!=top[lca]){
				seg::change(1,n,dfn[top[r]],dfn[r],1,val);
				r=fa[top[r]];
			}
			seg::change(1,n,dfn[tmp],dfn[r],1,val);
		}
	}
}
namespace ori{
	int cnte=-1;
	int id[100010],belong[100010];
	struct edge{
		int from,to,w,flag;
	}a[100010];
	vector<edge>e[100010];
	void init(){
		cnte=-1;
		memset(id,0,sizeof(id));
		memset(a,0,sizeof(a));
	}
	inline bool cmp(edge l,edge r){return l.w<r.w;}
	inline void add(int u,int v,int w){
//		cout<<"original add "<<u<<' '<<v<<' '<<w<<'\n';
		a[++cnte]=(edge){u,v,w,0};
	}
	int kruskal(){
		int i,u,v,ans=0,tmp,debug=0;
		sort(a,a+cnte+1,cmp);
		ufs::init(n);
		for(i=0;i<=cnte;i++){
			u=a[i].from;v=a[i].to;
			if(tmp=ufs::join(u,v)){
				a[i].flag=1;debug++;
				t::add(u,v,a[i].w);
				ans+=a[i].w;
				if(tmp==n) break;
			}
		}
		assert(debug==n-1);
		return ans;
	}
	void solve(){
		int i,tmp;
		for(i=0;i<=cnte;i++){//处理非MST边
			if(a[i].flag) continue;
			tmp=t::getlca(a[i].from,a[i].to);
			if(tmp!=a[i].from&&tmp!=a[i].to) e[tmp].push_back(a[i]);
			t::solve(a[i].from,a[i].to,a[i].w);
		}
	}
	inline int findb(int x){return ((belong[x]==x)?x:(belong[x]=findb(belong[x])));}//这里一定要再写一个集合合并
	void getans(int u,int fromw){//一次dfs处理每个点的答案
		int i,v,tot=0,x,y,cur=0,tmp=(n!=1),pre=e[u].size();
		if(u!=1) id[fa[u]]=++tot;
		for(i=t::first[u];~i;i=t::a[i].next){
			v=t::a[i].to;if(v==fa[u]) continue;
			getans(v,t::a[i].w);
			id[v]=++tot;
			fromw+=t::a[i].w;
			if(u!=1&&seg::query(1,n,dfn[v],1)!=1e9) e[u].push_back((edge){id[fa[u]],id[v],seg::query(1,n,dfn[v],1),0});//找最小边
		}
		for(i=0;i<pre;i++){
			e[u][i].from=id[findb(e[u][i].from)];
			e[u][i].to=id[findb(e[u][i].to)];
		}
		sort(e[u].begin(),e[u].end(),cmp);
		ufs::init(tot);
		for(i=0;i<e[u].size();i++){
			x=e[u][i].from;y=e[u][i].to;
			assert(x<=tot);
			assert(y<=tot);
			if(tmp=ufs::join(x,y)){
				cur+=e[u][i].w;
				if(tmp==tot) break;
			}
		}
		if(tot==1||tmp==tot) re[u]=base-fromw+cur;
		else re[u]=1e9;
		e[u].clear();
		for(i=t::first[u];~i;i=t::a[i].next) if(t::a[i].to!=fa[u]) belong[t::a[i].to]=u;
	}
}

int main(){
	int T=read(),i,j,t1,t2,t3,t4,tmp;
	while(T--){
		n=read();m=read();
		ori::init();t::init();
		for(i=1;i<=n;i++) ori::belong[i]=i;
		for(i=1;i<=m;i++){
			t1=read();t2=read();t3=read();t4=read();
			ori::add(t1,t2,t3*(1-t4));
		}
		base=ori::kruskal();
		t::dfs1(1,0);
		t::dfs2(1,1);
		t::ST();
		seg::build(1,n,1);
		ori::solve();
		ori::getans(1,0);
		for(i=1;i<=n;i++){
			if(re[i]!=1e9) printf("%d\n",re[i]);
			else puts("inf");
		}
		fflush(stdout);
	}
}
posted @ 2019-02-04 23:25  dedicatus545  阅读(267)  评论(0编辑  收藏  举报