Loading

求LCA&树上倍增法&洛谷4180&树链剖分——轻重链剖分

1LCA

对于有根树T的两个结点u、v,最近公共祖先LCA(T,u,v)表示一个结点x,满足x是u和v的祖先且x的深度尽可能大。在这里,一个节点也可以是它自己的祖先。

有时lca可以帮助我们解决许多问题,包括树上路径,如何快捷并高效的求出lca就成了一个重要的问题。

2求解

2.1树上倍增法

这是一个在线算法。

2.1.1 预处理

树上倍增法是一种十分重要的算法,多用于树上统计信息使用,不光是lca要用,只要是统计信息就可以考虑用树上倍增法。

我们设\(f_{k,i}\) 为i向根节点走\(2^i\)步的祖先,那么显然有\(f_{k,i}=f_{f_{k,i-1},i-1}\) 我们可以用深度作为阶段,从根节点往下依次处理。

因为以深度作为阶段,考虑用bfs作为预处理算法。在预处理f数组的同时,维护d数组,即节点深度。

代码:

queue<int> q;

inline void bfs(){
	q.push(s);d[s]=1;
	while(!q.empty()){
		int top=q.front();q.pop();
		for(int x=e.head[top];x;x=e.li[x].next){
			int to=e.li[x].to;
			if(d[to]) continue;
			d[to]=d[top]+1;
			f[to][0]=top;
			for(int i=1;i<=t;i++) f[to][i]=f[f[to][i-1]][i-1];
			q.push(to);
		}
	}
}

其中,我们用邻接表存边。注意,这里\(t=\log_2n\) n为节点总数。这么做的原因是有可能这n个节点是一条链,在这种情况下,显然t最大为\(\log_2n\)

具体实现看代码即可。

2.1.2 求lca

假设我们要求x和y的lca,不妨设x的深度比y要大(否则可以交换x和y),我们先利用二进制拆分把x和y调到同一深度。这一步明显是正确的,因为根据二进制拆分,所有的正整数都可以写成二的正整数幂的和的形式,所以我们对x和y的深度之差进行二进制拆分,经过有限次上调,可以把x调到y的位置。具体流程模仿二进制拆分,从大到小枚举i,如果上调\(2^i\)步到达的点的深度没有y深,那么上调x。

当x和y在同一深度时,如果他们相等,则他们的lca就是y,这说明y是x的一个祖先。否则,我们让x和y一起上调,在上调之后x和y不相等的情况下,尽量往上调,同样还是利用二进制拆分。这样,调完之后的x和y的父节点即为lca。

为什么要x和y上调之后不相等呢?因为我们要保证x和y不能调过了,调过了再往下调就比较麻烦。

代码:

inline int lca(int x,int y){
	if(d[x]<d[y]) Swap(x,y);
	for(int i=t;i>=0;i--) if(d[f[x][i]]>=d[y]) x=f[x][i];
	if(x==y) return x;
	for(int i=t;i>=0;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
	return f[x][0];
}

2.1.3总代码:

时间复杂度\(O((n+m)\log n)\)

#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<sstream>
#include<queue>
#include<map>
#include<vector>
#include<set>
#include<deque>
#include<cstdlib>
#include<ctime>
#define dd double
#define ld long double
#define ll long long
#define ull unsigned long long
#define N 500010
#define M number
using namespace std;

inline ll read()
{
	ll x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

inline void Swap(int &a,int &b){
	a^=b;b^=a;a^=b;
}

struct EDGE{
	struct edge{
		int next,to;
		inline void intt(int ne_,int to_){
			next=ne_;to=to_;
		}
	};
	edge li[N*2];int head[N],tail;
	
	inline void add(int from,int to){
		li[++tail].intt(head[from],to);
		head[from]=tail;
	}
};
EDGE e;

int f[N][20],d[N],n,m,s,t;

queue<int> q;

inline void bfs(){
	q.push(s);d[s]=1;
	while(!q.empty()){
		int top=q.front();q.pop();
		for(int x=e.head[top];x;x=e.li[x].next){
			int to=e.li[x].to;
			if(d[to]) continue;
			d[to]=d[top]+1;
			f[to][0]=top;
			for(int i=1;i<=t;i++) f[to][i]=f[f[to][i-1]][i-1];
			q.push(to);
		}
	}
}

inline int lca(int x,int y){
	if(d[x]<d[y]) Swap(x,y);
	for(int i=t;i>=0;i--) if(d[f[x][i]]>=d[y]) x=f[x][i];
	if(x==y) return x;
	for(int i=t;i>=0;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
	return f[x][0];
}

int main(){
	n=read();m=read();s=read();
	t=(int)(log(n)/log(2))+1;
	for(int i=1;i<=n-1;i++){
		int from,to;
		from=read();to=read();
		e.add(from,to);e.add(to,from);
	}
	bfs();
	for(int i=1;i<=m;i++){
		int a,b;
		a=read();b=read();
		printf("%d\n",lca(a,b));
	}
	return 0;
}

注意几点:

  1. 邻接表存边无向边开2倍。
  2. 因为是无向边,注意不要让子节点在回到父节点,即这一行代码:
if(d[to]) continue;
  1. 代码中t的求法证明如下:

\(2^t=n,e^a=2,e^b=n\),则有\(2^t=(e^a)^t=e^{at}=n=e^b\) 所以\(t=b/a\),即\(\log_2n=\ln n/\ln 2\)

c++中,log以e为底,其中e为自然对数。

2.2 tarjan求lca

这是一个离线算法。

2.2.1思路

再树深度优先便利的任意时刻,树上的节点有3类:

  1. 已经访问但未回溯的节点,标记为1
  2. 已经访问并且已经回溯的节点,标记为2。
  3. 未访问的节点,不打标记。

那么对于一个标记为1的节点x,其所有祖先节点必定是标记为1。

若我们要求解x和另一个节点y的lca,如果y的标记是2,那么y和x的lca必定是y朝根节点走时遇到的第一个标记为1的节点。

所以接下来的问题是如何快捷的求出某个标记为2的节点朝根节点走遇到的第一个标记为1的节点。

我们可以借助并查集求解。

2.2.2并查集使用

并查集代码:

struct DSU{
	int fa[N];
	
	inline void init(){
		for(int i=1;i<=n;i++) fa[i]=i;
	}
	
	inline int find(int x){
		return x==fa[x]?x:fa[x]=find(fa[x]);
	}
	
	inline bool hebing(int x,int y){
		int fax=find(x),fay=find(y);
		if(fax==fay) return 0;
		fa[fax]=fay;
		return 1;
	}
	
	inline void print(){
		for(int i=1;i<=n;i++) printf("%d ",fa[i]);
		printf("\n");
	}
};
DSU bcj;

怎么求解:我们只需要在某个节点u标记为2之后,即回溯之后,把它和它的父节点所在集合合并,因为这时以u为根节点的子树一定全部标记为2,而u的父节点一定标记为1,所以这时如果查找节点u所在子树的任意节点所在集合的代表元素,一定是这个节点朝根节点走时遇到的第一个标记为1的节点。这时要注意,我们在合并时,必须要让标记为1的节点当代表元素。具体过程即是在调用上述并查集代码的hebing操作时,让x为u,y为u父节点。

2.2.3 tarjan流程

我们从根节点开始往下执行dfs,每访问一个节点k,把该节点的标记打上1,然后处理k的所有子节点。处理完后把k的每个子节点所在集合与k所在集合合并,之后查看所有关于k的询问,如果另一个节点的标记为2则该询问答案为另一个节点所在集合的代表元素。如果标记不为2,则不用管他,因为当这另一个节点标记为2的时候,k节点标记一定为2,这时候可以处理该询问。回答完毕后,把k节点标记2。

代码:

inline void tarjan(int k){
	vis[k]=1;
	for(int x=e.head[k];x;x=e.li[x].next){
		int to=e.li[x].to;
		if(vis[to]) continue;
		tarjan(to);
		bcj.hebing(to,k);
	}
	for(int i=0;i<ask[k].size();i++){
		int y=ask[k][i];
		if(vis[y]!=2) continue;
		int lca=bcj.find(y);
		ans[ask_id[k][i]]=lca;
	}
	vis[k]=2;
}

2.2.4总代码

依然使用邻接表存图,这里使用两个vector来存每个点的询问以及该询问是第几个询问,开一个数组来保存答案。

详细实现见代码:

#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<sstream>
#include<queue>
#include<map>
#include<vector>
#include<set>
#include<deque>
#include<cstdlib>
#include<ctime>
#define dd double
#define ld long double
#define ll long long
#define ull unsigned long long
#define N 500100
#define M number
using namespace std;

inline ll read()
{
	ll x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

int n,m,s;

struct EDGE{
	struct edge{
		int next,to;
		inline void intt(int ne_,int to_){
			next=ne_;to=to_;
		}
	};
	edge li[N*2];int head[N],tail;
	
	inline void add(int from,int to){
		li[++tail].intt(head[from],to);
		head[from]=tail;
	}
};
EDGE e;

struct DSU{
	int fa[N];
	
	inline void init(){
		for(int i=1;i<=n;i++) fa[i]=i;
	}
	
	inline int find(int x){
		return x==fa[x]?x:fa[x]=find(fa[x]);
	}
	
	inline bool hebing(int x,int y){
		int fax=find(x),fay=find(y);
		if(fax==fay) return 0;
		fa[fax]=fay;
		return 1;
	}
	
	inline void print(){
		for(int i=1;i<=n;i++) printf("%d ",fa[i]);
		printf("\n");
	}
};
DSU bcj;

vector<int> ask[N],ask_id[N];
int vis[N],ans[N];

inline void tarjan(int k){
	vis[k]=1;
	for(int x=e.head[k];x;x=e.li[x].next){
		int to=e.li[x].to;
		if(vis[to]) continue;
		tarjan(to);
		bcj.hebing(to,k);
	}
	for(int i=0;i<ask[k].size();i++){
		int y=ask[k][i];
		if(vis[y]!=2) continue;
		int lca=bcj.find(y);
		ans[ask_id[k][i]]=lca;
	}
	vis[k]=2;
}

int main(){
	n=read();m=read();s=read();
	bcj.init();
	for(int i=1;i<=n-1;i++){
		int from,to;
		from=read();to=read();
		e.add(from,to);e.add(to,from);
	}
	for(int i=1;i<=m;i++){
		int x,y;
		x=read();y=read();
		ask[x].push_back(y);ask_id[x].push_back(i);
		ask[y].push_back(x);ask_id[y].push_back(i);
	}
	tarjan(s);
	for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
	return 0;
}

自我感觉代码可读性极好

不像某dyk和某zyc

2.3树链剖分

树链剖分,计算机术语,指一种对树进行划分的算法,它先通过轻重边剖分将树分为多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、BST、SPLAY、线段树等)来维护每一条链。

在维护的时候,我们把树上的节点映射到一个区间上去,在这个区间上的节点编号即为树上节点的bfs序。

2.3.1定义与变量

我们先把这些链处理出来。

定义:节点k的重儿子为k的所有子节点中子树个数最大的那个。

定义:重边为节点k向它的重儿子连的一条边。

定义:重链为重边组成的树上的一段路径。

我们需要维护几个数组:

  1. \(size[k]\)表示以k为根子树大小
  2. \(deep[k]\)表示k节点深度。
  3. \(top[k]表示\)k所在重链的顶端节点。
  4. \(son[k]\)表示k节点的重儿子。
  5. \(fa[k]\)表示k节点的父亲节点。
  6. \(id[k]\)表示树上的节点k的dfs序,即节点k在区间上的节点编号。
  7. \(rk[g]\)表示区间节点编号为g的节点在树上的节点编号。

\(id\)\(rk\)相当于两个映射,前者把树上的点映射到区间上,后者把区间上的点映射到树上。

2.3.2预处理

我们用两个dfs来维护上述数组。

第一个dfs我们维护数组\(deep,fa,size,son\)

dfs两个参数,当前节点和当前节点的父节点。特别的,根节点的父节点是0号节点。

具体流程看代码即可。

inline void dfs1(int k,int f) {
	deep[k]=deep[f]+1;
	fa[k]=f;
	size[k]=1;
	for(int x=head[k];x;x=li[x].next) {
		int to=li[x].to;
		if(to==f) continue;
		dfs1(to,k);
		size[k]+=size[to];
		if(size[to]>size[son[k]]) son[k]=to;
	}
}

第二个dfs我们来维护数组\(id,rk,top\)

dfs两个参数,当前节点和当前节点所在重链的顶端节点。

特别注意,我们应该先去dfs重儿子在dfs其它子节点,其原因是要保证dfs序连续,方便我们的区间处理。

代码:

inline void dfs2(int k,int t) {
	id[k]=++tot;
	rk[tot]=k;
	top[k]=t;
	if(!son[k]) return;
	dfs2(son[k],t);
	for(int x=head[k]; x; x=li[x].next) {
		int to=li[x].to;
		if(to!=fa[k]&&to!=son[k]) dfs2(to,to);
	}
}

然后我们考虑知道以上信息如何求lca。

不妨设节点x和y,我们要求x和y的lca,且x的深度大于y。

如果两个节点x和y的重链顶端节点是同一个节点,那么说明y为x的祖先,直接返回y即可。

否则,我们把x上调至其重链顶端节点的父节点上,重复上述过程,即可求得lca。

中间一定要保证x的深度大于y的深度,否则交换x和y。

代码:

inline void lca(int x,int y){
	while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]]) Swap(x,y);
		x=fa[top[x]];
	}
	if(deep[x]>deep[y]) Swap(x,y);
	printf("%d\n",x);
}

2.3.3 总代码

#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<sstream>
#include<queue>
#include<map>
#include<vector>
#include<set>
#include<deque>
#include<cstdlib>
#include<ctime>
#define dd double
#define ld long double
#define ll long long
#define ull unsigned long long
#define N 500010
#define M number
using namespace std;

const int INF=0x3f3f3f3f;

inline ll read()
{
	ll x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

inline void Swap(int &a,int &b){
	a^=b;b^=a;a^=b;
}

struct edge{
	int to,next;
	inline void intt(int to_,int ne_){
		to=to_;next=ne_;
	}
};
edge li[N*2];int head[N],tail;

inline void add(int from,int to){
	li[++tail].intt(to,head[from]);
	head[from]=tail;
}

int son[N],deep[N],fa[N],size[N],top[N];
int n,m,s;

inline void dfs1(int k,int f){
	deep[k]=deep[f]+1;
	fa[k]=f;
	size[k]=1;
	for(int x=head[k];x;x=li[x].next){
		int to=li[x].to;
		if(to==f) continue;
		dfs1(to,k);
		size[k]+=size[to];
		if(size[to]>size[son[k]]) son[k]=to;
	}
}

inline void dfs2(int k,int t){
	top[k]=t;
	if(!son[k]) return;
	dfs2(son[k],t);
	for(int x=head[k];x;x=li[x].next){
		int to=li[x].to;
		if(to!=fa[k]&&to!=son[k]) dfs2(to,to); 
	}
}

inline void lca(int x,int y){
	while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]]) Swap(x,y);
		x=fa[top[x]];
	}
	if(deep[x]>deep[y]) Swap(x,y);
	printf("%d\n",x);
}

int main(){
	scanf("%d%d%d",&n,&m,&s);
	for(int i=1;i<=n-1;i++){
		int from,to;
		scanf("%d%d",&from,&to);
		add(from,to);
		add(to,from);
	}
	dfs1(s,0);
	dfs2(s,s);
	for(int i=1;i<=m;i++){
		int a,b;
		scanf("%d%d",&a,&b);
		lca(a,b);
	}
	return 0;
}

2.3.4树链剖分其它应用

之前讲过熟练剖分可以把树上节点映射到区间上,所以所有对树上节点或者路径的操作都可以转化到区间操作,故也就可以用区间数据结构维护。

修改树上路径的信息,例如修改节点x到节点y的信息,就是先修改x到\(LCA(x,y)\)的信息,在修改\(LCA(x,y)\)到y的信息,所以修改的过程可以仿照求lca的过程,反正重链上的dfs序连续。

洛谷:轻重链剖分代码:

#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<sstream>
#include<queue>
#include<map>
#include<vector>
#include<set>
#include<deque>
#include<cstdlib>
#include<ctime>
#define dd double
#define ld long double
#define ll long long
#define ull unsigned long long
#define N 100010
#define M number
using namespace std;

const int INF=0x3f3f3f3f;

inline ll read() {
	ll x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9') {
		if(ch=='-')f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9') {
		x=x*10+ch-'0';
		ch=getchar();
	}
	return x*f;
}

inline void Swap(int &a,int &b) {
	a^=b;
	b^=a;
	a^=b;
}

struct edge {
	int to,next;
	inline void intt(int to_,int ne_) {
		to=to_;
		next=ne_;
	}
};
edge li[N*2];
int head[N],tail;

inline void add(int from,int to) {
	li[++tail].intt(to,head[from]);
	head[from]=tail;
}

int son[N],deep[N],fa[N],size[N],top[N],id[N],rk[N],tot;
int n,m,r,mod,a[N];

struct ST {
	struct rode {
		int val,ad,len;
	};
	rode p[N*4];
	
	inline ST(){
		memset(p,0,sizeof(p));
	}

	inline void pushup(int k) {
		p[k].val=(p[k*2].val+p[k*2+1].val)%mod;
		p[k].len=p[k*2].len+p[k*2+1].len;
	}

	inline int A(int k,int x) {
		(p[k].val+=p[k].len*x%mod)%=mod;
		(p[k].ad+=x)%=mod;
	}

	inline void pushdown(int k) {
		A(k*2,p[k].ad);
		A(k*2+1,p[k].ad);
		p[k].ad=0;
	}
	
	inline void build(int k,int l,int r) {
		if(l==r) {
			p[k].val=a[rk[l]];
			p[k].len=1;
			return;
		}
		int mid=l+r>>1;
		build(k*2,l,mid);
		build(k*2+1,mid+1,r);
		pushup(k);
	}

	inline void change(int k,int l,int r,int z,int y,int x) {
		if(l==z&&r==y) {
			A(k,x);
			return;
		}
		if(p[k].ad) pushdown(k);
		int mid=l+r>>1;
		if(y<=mid) change(k*2,l,mid,z,y,x);
		else if(z>mid) change(k*2+1,mid+1,r,z,y,x);
		else change(k*2,l,mid,z,mid,x),change(k*2+1,mid+1,r,mid+1,y,x);
		pushup(k);
	}
	
	inline int ask_sum(int k,int l,int r,int z,int y){
		if(l==z&&r==y) return p[k].val;
		if(p[k].ad) pushdown(k);
		int mid=l+r>>1;
		if(y<=mid) return ask_sum(k*2,l,mid,z,y);
		else if(z>mid) return ask_sum(k*2+1,mid+1,r,z,y);
		else return (ask_sum(k*2,l,mid,z,mid)+ask_sum(k*2+1,mid+1,r,mid+1,y))%mod;
	}
};
ST stree;

inline void dfs1(int k,int f) {
	deep[k]=deep[f]+1;
	fa[k]=f;
	size[k]=1;
	for(int x=head[k];x;x=li[x].next) {
		int to=li[x].to;
		if(to==f) continue;
		dfs1(to,k);
		size[k]+=size[to];
		if(size[to]>size[son[k]]) son[k]=to;
	}
}

inline void dfs2(int k,int t) {
	id[k]=++tot;
	rk[tot]=k;
	top[k]=t;
	if(!son[k]) return;
	dfs2(son[k],t);
	for(int x=head[k]; x; x=li[x].next) {
		int to=li[x].to;
		if(to!=fa[k]&&to!=son[k]) dfs2(to,to);
	}
}

inline void update(int a,int b,int x){
	while(top[a]!=top[b]){
		if(deep[top[a]]<deep[top[b]]) Swap(a,b);
		stree.change(1,1,n,id[top[a]],id[a],x);
		a=fa[top[a]];
	}
	if(deep[a]>deep[b]) Swap(a,b);
	stree.change(1,1,n,id[a],id[b],x);
}

inline int ask(int a,int b){
	int ans=0;
	while(top[a]!=top[b]){
		if(deep[top[a]]<deep[top[b]]) Swap(a,b);
		(ans+=stree.ask_sum(1,1,n,id[top[a]],id[a]))%=mod;
		a=fa[top[a]];
	}
	if(deep[a]>deep[b]) Swap(a,b);
	(ans+=stree.ask_sum(1,1,n,id[a],id[b]))%=mod;
	return ans;
}

int main() {
	scanf("%d%d%d%d",&n,&m,&r,&mod);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]),a[i]%=mod;
	for(int i=1;i<=n-1;i++){
		int from,to;
		scanf("%d%d",&from,&to);
		add(from,to);
		add(to,from);
	}
	dfs1(r,0);
	dfs2(r,r);
	stree.build(1,1,n);
	for(int i=1;i<=m;i++){
		int op,x,y,z;
		scanf("%d",&op);
		if(op==1){
			scanf("%d%d%d",&x,&y,&z);
			update(x,y,z);
		}
		else if(op==2){
			scanf("%d%d",&x,&y);
			printf("%d\n",ask(x,y));
		}
		else if(op==3){
			scanf("%d%d",&x,&z);
			stree.change(1,1,n,id[x],id[x]+size[x]-1,z);
		}
		else if(op==4){
			scanf("%d",&x);
			printf("%d\n",stree.ask_sum(1,1,n,id[x],id[x]+size[x]-1));
		}
	}
	return 0;
}

3树上倍增法的应用

3.1洛谷4180

双倍经验LOJ

题目意思应该能看懂,这里主要是用树上倍增法来预处理一些数组,这里先简单说一下思路。

你需要先求出整张图的最小生成树来。

然后考虑,严格次小生成树一定是除了一条边的权值以外,其余边的权值都和最小生成树上的权值相同,否则一定不是严格次小。

换言之,如果不考虑权值重复的话,那么严格次小生成树和最小生成树只有一条边的差别。

我们考虑这个图上除了最小生成树以外的边,我们简称为树外边。

假设这个树外边的两个端点是x和y,显而易见的是,x和y在最小生成树(下面简称mst)上的路径与树外边\((x,y)\) 组成了一个环。

考虑树外边\((x,y)\)与该环最大值和次大值的关系。

如果树外边权值大于该环最大值,那么用树外边替代这个最大值的生成树的值可以作为我们的候选答案。

如果树外边权值等于该环最大值,那么用树外边替代这个次大值的生成树的值可以作为我们的候选答案。

树外边权值不可能小于该环最大值,否则不符合mst定义。

所以我们以任意一点为跟,用树上倍增法求一下任意节点到根节点路径上的最大值和次大值。显然,这个值也可以用树链剖分来做。

因为作者太懒,没有树链剖分代码

\(f_{x,i}\)表示节点x向上走\(2^i\)步的祖先。显然有\(f_{x,i}=f_{f_{x,i-1},i-1}\)

\(g_{x,i,k}\)表示节点x向上走\(2^i\)步所经过的边中的最大值与次大值,k=0最大值,k=1次大值。

那么有:

\[g_{x,i,0}=\max \{g_{x,i-1,0},g_{f_{x,i-1,0},i-1,0}\}\\ g_{x,i,1}=\max\limits_{g_{x,i-1,0}=g_{f_{x,i-1},i-1,0}}\{g_{x,i-1,1},g_{f_{x,i-1},x-1,1}\}\\ g_{x,i,1}=\max\limits_{g_{x,i-1,0}>g_{f_{x,i-1},i-1,0}}\{g_{x,i-1,1},g_{f_{x,i-1},x-1,0}\}\\ g_{x,i,1}=\max\limits_{g_{x,i-1,0}<g_{f_{x,i-1},i-1,0}}\{g_{x,i-1,0},g_{f_{x,i-1},x-1,1}\}\\ \]

由此,我们可以以深度为阶段,dp上述数组。

求完之后,我们可以利用求lca的过程,处理x和y到\(LCA(x,y)\)路径上的最大值和次大值,求法与g数组的类似。

但是我写的太丑了

现放代码如下:

#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<sstream>
#include<queue>
#include<map>
#include<vector>
#include<set>
#include<deque>
#include<cstdlib>
#include<ctime>
#define dd double
#define int long long
#define ld long double
#define ll long long
#define ull unsigned long long
#define N 700010
#define M 900010
using namespace std;

const ll INF=1e14;

inline ll read()
{
	ll x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

inline ll Max(ll a,ll b){
	return a>b?a:b;
}

inline ll Min(ll a,ll b){
	return a>b?b:a;
}

inline void Swap(ll &a,ll &b){
	a^=b;b^=a;a^=b;
}

struct DSU{
	ll fa[N];
	
	inline void init(ll n){
		for(ll i=1;i<=n;i++) fa[i]=i;
	}
	
	inline ll find(ll x){
		return x==fa[x]?x:fa[x]=find(fa[x]);
	}
	
	inline bool merge(ll x,ll y){
		ll fax=find(x),fay=find(y);
		if(fax==fay) return 0;
		fa[fax]=fay;
		return 1;
	}
};

struct EDGE{
	struct edge{
		ll from,next,to,w;
		
		inline void read(ll ne_,ll from_,ll to_,ll w_){
			next=ne_;to=to_;from=from_;w=w_;
		}
		
		inline bool operator < (const edge &b) const {
			return w<b.w;
		}
	};
	edge li[M*2],ed[M*2];
	ll head[N],tail;
	
	inline void add(ll from,ll to,ll w,bool op){
		li[++tail].read(head[from],from,to,w);
		head[from]=tail;
		if(op) ed[tail/2].read(tail,from,to,w);
	}
};

DSU dsu;EDGE e;
ll n,m,f[N][20],g[N][20][2],d[N],ans=INF,t,minval;//0 max 1 2max
bool in_mst[M*2],vis[N];

inline ll kruskal(){
	ll sum=0;
	sort(e.ed+1,e.ed+1+m);
	for(ll i=1;i<=m;i++)
		if(dsu.merge(e.ed[i].from,e.ed[i].to)) in_mst[e.ed[i].next]=1,in_mst[e.ed[i].next-1]=1,sum+=e.ed[i].w;
	return sum;
}

queue<ll> q;
inline void prework(){
	for(int i=0;i<=t;i++) g[0][i][0]=g[0][i][1]=g[1][i][0]=g[1][i][1]=INF;
	for(int i=0;i<=n;i++) g[i][0][1]=-INF;
	q.push(1);d[1]=1;
	while(!q.empty()){
		ll top=q.front();q.pop();
		vis[top]=1;
		for(ll x=e.head[top];x;x=e.li[x].next){
			ll to=e.li[x].to;
			if(vis[to]) continue;
			if(!in_mst[x]) continue; 
			vis[to]=1;
			d[to]=d[top]+1;
			f[to][0]=top;g[to][0][0]=e.li[x].w;
			for(ll i=1;i<=t;i++){
				f[to][i]=f[f[to][i-1]][i-1];
				g[to][i][0]=Max(g[to][i-1][0],g[f[to][i-1]][i-1][0]);
				if(g[to][i-1][0]==g[f[to][i-1]][i-1][0]) g[to][i][1]=Max(g[to][i-1][1],g[f[to][i-1]][i-1][1]);
				else if(g[to][i-1][0]>g[f[to][i-1]][i-1][0]) g[to][i][1]=Max(g[to][i-1][1],g[f[to][i-1]][i-1][0]);
				else g[to][i][1]=Max(g[to][i-1][0],g[f[to][i-1]][i-1][1]);
			}
			q.push(to);
		}
	}
}

inline ll comp(ll fir,ll sec,ll w){
	if(w>fir) return minval-fir+w;
	return minval-sec+w; 
}

inline ll solve(ll x,ll y,ll w){
	if(d[x]<d[y]) Swap(x,y);
	ll maxx=-INF,cimaxx=-INF;
	for(int i=t;i>=0;i--){
		if(g[x][i][1]==INF) g[x][i][1]=-INF;
		if(g[x][i][0]==INF) g[x][i][0]=-INF;
		if(d[f[x][i]]>=d[y]){
			if(maxx<g[x][i][0]){
				cimaxx=Max(maxx,g[x][i][1]);
				maxx=g[x][i][0];
			}
			else if(maxx==g[x][i][0]) cimaxx=Max(cimaxx,g[x][i][1]);
			else cimaxx=Max(g[x][i][0],cimaxx);
			x=f[x][i];
		}
	}
	if(x==y) return comp(maxx,cimaxx,w);
	for(int i=t;i>=0;i--){
		if(g[x][i][1]==INF) g[x][i][1]=-INF;
		if(g[x][i][0]==INF) g[x][i][0]=-INF;
		if(g[y][i][1]==INF) g[y][i][1]=-INF;
		if(g[y][i][0]==INF) g[y][i][0]=-INF;
		if(f[x][i]!=f[y][i]){
			if(maxx<g[x][i][0]){
				cimaxx=Max(maxx,g[x][i][1]);
				maxx=g[x][i][0];
			}
			else if(maxx==g[x][i][0]) cimaxx=Max(cimaxx,g[x][i][1]);
			else cimaxx=Max(g[x][i][0],cimaxx);
			if(maxx<g[y][i][0]){
				cimaxx=Max(maxx,g[y][i][1]);
				maxx=g[y][i][0];
			}
			else if(maxx==g[y][i][0]) cimaxx=Max(cimaxx,g[y][i][1]);
			else cimaxx=Max(g[y][i][0],cimaxx);
			x=f[x][i];y=f[y][i];
		}
	}
	if(g[x][0][1]==INF) g[x][0][1]=-INF;
	if(g[x][0][0]==INF) g[x][0][0]=-INF;
	if(g[y][0][1]==INF) g[y][0][1]=-INF;
	if(g[y][0][0]==INF) g[y][0][0]=-INF;
	if(maxx<g[x][0][0]){
		cimaxx=Max(maxx,g[x][0][1]);
		maxx=g[x][0][0];
	}
	else if(maxx==g[x][0][0]) cimaxx=Max(cimaxx,g[x][0][1]);
	else cimaxx=Max(g[x][0][0],cimaxx);
	if(maxx<g[y][0][0]){
		cimaxx=Max(maxx,g[y][0][1]);
		maxx=g[y][0][0];
	}
	else if(maxx==g[y][0][0]) cimaxx=Max(cimaxx,g[y][0][1]);
	else cimaxx=Max(g[y][0][0],cimaxx);
	return comp(maxx,cimaxx,w);
}

signed main(){
	n=read();m=read();
	for(ll i=1;i<=m;i++){
		ll from=read(),to=read(),w=read();
		if(from==to) continue;
		e.add(from,to,w,0);e.add(to,from,w,1);
	}
	t=(ll)(log(n)/log(2))+1;
	dsu.init(n);
	minval=kruskal();
	prework();
	for(ll i=1;i<=m;i++){
		if(!in_mst[e.ed[i].next]){
			ans=Min(ans,solve(e.ed[i].from,e.ed[i].to,e.ed[i].w));
		}
	}
	printf("%lld\n",ans);
	return 0;
}

几个很毒瘤的地方:

  1. 洛谷上第10个点答案有\(10^{13}\)那么大,所以一定要注意你的INF取值。
  2. 很明显有些g数组的值是不存在的,这里我的做法是先将0号节点取为INF,这样所有2的幂跳过了的节点的值都为INF,不存在的值暂且为-INF,因为不能妨碍g数组取值。在求解x和y路径上的最大值和次大值时,我有把所有的INF换成了-INF,不妨碍最大值,次大值的取值。
  3. 不要忘记开long long
  4. 一定要多写函数。

4引用

  1. Dfkuaid树链剖分
  2. 百度百科——树链剖分
  3. 百度百科——lca
  4. 《算法竞赛进阶指南》
posted @ 2021-03-21 11:39  hyl天梦  阅读(128)  评论(0编辑  收藏  举报