[LOJ3046][ZJOI2019]语言:树链的并+线段树合并

分析

问题显然可以转化为对于每个节点询问所有这个节点的所有链的链并的大小。

考场上我直接通过树剖打标记+树剖线段树维护以\(O(n \log^3 n)\)的时间复杂度暴力实现了这个过程。(使用LCT或者全局平衡二叉树可以实现\(O(n \log^2 n)\)的时间复杂度)

考虑如何快速求出链并的大小,有这样一个结论:把所有的链的端点按dfs序排序后,链并的大小等于所有链的两端点的深度之和减去相邻端点的LCA的深度之和再减去所有端点的LCA的深度,这个结论(貌似)在链并是一个连通块的时候均成立。

有了这个结论,我们就可以快乐地线段树合并了,时间复杂度为\(O(n \log n)\)

代码

考场代码(\(O(n \log^3 n)\)

这个算法可以过掉本题,但是无法通过UOJ上的HACK数据。

// 舞台赋予了我们生存的意义 
// 让荣光停落于刀锋之上 
// ——Star Divine -Finale- 

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <algorithm>
#include <utility>
#include <vector>
#include <queue>
#include <set>
#include <map>

#define rin(i,a,b) for(int i=(a);i<=(b);++i)
#define irin(i,a,b) for(int i=(a);i>=(b);--i)
#define trav(i,a) for(int i=head[a];i;i=e[i].nxt)
#define Size(a) (int)a.size()
#define pb push_back
#define mkpr std::make_pair
#define fi first
#define se second
#define lowbit(a) ((a)&(-(a)))
typedef long long LL;

using std::cerr;
using std::endl;

inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

const int MAXN=100005;

int n,m,ecnt,head[MAXN];

struct Edge{
	int to,nxt;
}e[MAXN<<1];

inline void add_edge(int bg,int ed){
	++ecnt;
	e[ecnt].to=ed;
	e[ecnt].nxt=head[bg];
	head[bg]=ecnt;
}

int fa[MAXN],dep[MAXN],siz[MAXN],pc[MAXN],top[MAXN];
int tot,id[MAXN],num[MAXN],len[MAXN],toparr[MAXN],cnt;

void dfs1(int x,int pre,int dept){
	fa[x]=pre;
	dep[x]=dept;
	siz[x]=1;
	int maxsiz=-1;
	trav(i,x){
		int ver=e[i].to;
		if(ver==pre)continue;
		dfs1(ver,x,dept+1);
		siz[x]+=siz[ver];
		if(siz[ver]>maxsiz){
			maxsiz=siz[ver];
			pc[x]=ver;
		}
	}
}

void dfs2(int x,int topf){
	top[x]=topf;
	id[x]=++tot;
	num[tot]=x;
	++len[topf];
	if(!pc[x])return;
	dfs2(pc[x],topf);
	trav(i,x){
		int ver=e[i].to;
		if(ver==fa[x]||ver==pc[x])continue;
		toparr[++cnt]=ver;
		dfs2(ver,ver);
	}
}

int ss[MAXN],tt[MAXN];
std::vector<int> vec[MAXN];

inline void set_tag(int x,int y,int pid){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]])std::swap(x,y);
		vec[id[top[x]]].pb(pid);
		vec[id[x]+1].pb(-pid);
		x=fa[top[x]];
	}
	if(dep[x]>dep[y])std::swap(x,y);
	vec[id[x]].pb(pid);
	vec[id[y]+1].pb(-pid);
}

#define mid ((l+r)>>1)

int minn[MAXN<<2],mncnt[MAXN<<2],tag[MAXN<<2],ql,qr,kk;
int sgn,root[MAXN],lc[MAXN<<2],rc[MAXN<<2];

inline void push_tag(int o,int _kk){
	minn[o]+=_kk;
	tag[o]+=_kk;
}

inline void pushdown(int o){
	if(!tag[o])return;
	push_tag(lc[o],tag[o]);
	push_tag(rc[o],tag[o]);
	tag[o]=0;
}

inline void pushup(int o){
	if(minn[lc[o]]<minn[rc[o]]){
		minn[o]=minn[lc[o]];
		mncnt[o]=mncnt[lc[o]];
	}
	else if(minn[lc[o]]>minn[rc[o]]){
		minn[o]=minn[rc[o]];
		mncnt[o]=mncnt[rc[o]];
	}
	else{
		minn[o]=minn[lc[o]];
		mncnt[o]=mncnt[lc[o]]+mncnt[rc[o]];
	}
}

int build(int l,int r){
	int o=++sgn;
	if(l==r){
		minn[o]=0;
		mncnt[o]=1;
		return o;
	}
	lc[o]=build(l,mid);
	rc[o]=build(mid+1,r);
	pushup(o);
	return o;
}

void add(int o,int l,int r){
	if(ql<=l&&r<=qr){
		push_tag(o,kk);
		return;
	}
	pushdown(o);
	if(mid>=ql)add(lc[o],l,mid);
	if(mid<qr)add(rc[o],mid+1,r);
	pushup(o);
}

/*
void write(int o,int l,int r){
	if(l==r){
		cerr<<minn[o]<<" ";
		return;
	}
	pushdown(o);
	write(lc,l,mid);
	write(rc,mid+1,r);
}
*/

#undef mid

int nowans;

inline void path_add(int x,int y){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]])std::swap(x,y);
		if(minn[root[top[x]]]==0)nowans-=mncnt[root[top[x]]];
		ql=1,qr=id[x]-id[top[x]]+1;
		add(root[top[x]],1,len[top[x]]);
		if(minn[root[top[x]]]==0)nowans+=mncnt[root[top[x]]];
		x=fa[top[x]];
	}
	if(dep[x]>dep[y])std::swap(x,y);
	if(minn[root[top[x]]]==0)nowans-=mncnt[root[top[x]]];
	ql=id[x]-id[top[x]]+1,qr=id[y]-id[top[y]]+1;
	add(root[top[x]],1,len[top[x]]);
	if(minn[root[top[x]]]==0)nowans+=mncnt[root[top[x]]];
}

/*
inline int calc(){
	int ret=0;
	rin(i,1,cnt)if(minn[root[toparr[i]]]==0)ret+=mncnt[root[toparr[i]]];
	return ret;
}
*/

int main(){
	freopen("language.in","r",stdin);
	freopen("language.out","w",stdout);
	n=read(),m=read();
	rin(i,2,n){
		int u=read(),v=read();
		add_edge(u,v);
		add_edge(v,u);
	}
	toparr[++cnt]=1;
	dfs1(1,0,1);
	dfs2(1,1);
	rin(i,1,cnt)root[toparr[i]]=build(1,len[toparr[i]]);
	rin(i,1,m){
		ss[i]=read(),tt[i]=read();
		set_tag(ss[i],tt[i],i);
	}
	LL ans=0;nowans=n;
	rin(i,1,n){
		if(minn[root[top[num[i]]]]==0)nowans-=mncnt[root[top[num[i]]]];
		ql=qr=i-id[top[num[i]]]+1,kk=1;
		add(root[top[num[i]]],1,len[top[num[i]]]);
		if(minn[root[top[num[i]]]]==0)nowans+=mncnt[root[top[num[i]]]];
		rin(j,0,Size(vec[i])-1){
			int pid=vec[i][j];
			kk=1;
			if(pid<0)pid=-pid,kk=-1;
			path_add(ss[pid],tt[pid]);
		}
		ans+=n-nowans-1;
		if(minn[root[top[num[i]]]]==0)nowans-=mncnt[root[top[num[i]]]];
		ql=qr=i-id[top[num[i]]]+1,kk=-1;
		add(root[top[num[i]]],1,len[top[num[i]]]);
		if(minn[root[top[num[i]]]]==0)nowans+=mncnt[root[top[num[i]]]];
	}
	printf("%lld\n",ans>>1);
	return 0;
}

/*
5 3
1 2
1 3
3 4
3 5
3 4
1 4
2 5

8
*/

正解(\(O(n \log n)\)

#include <bits/stdc++.h>

#define rin(i,a,b) for(int i=(a);i<=(b);++i)
#define irin(i,a,b) for(int i=(a);i>=(b);--i)
#define trav(i,a) for(int i=head[a];i;i=e[i].nxt)
#define Size(a) (int)a.size()
#define pb push_back
#define mkpr std::make_pair
#define fi first
#define se second
#define lowbit(a) ((a)&(-(a)))
typedef long long LL;

using std::cerr;
using std::endl;

inline LL read(){
	LL x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

const int MAXN=100005;

int n,m,ecnt,head[MAXN];

struct Edge{
	int to,nxt;
}e[MAXN<<1];

inline void add_edge(int bg,int ed){
	++ecnt;
	e[ecnt].to=ed;
	e[ecnt].nxt=head[bg];
	head[bg]=ecnt;
}

int fa[MAXN],dep[MAXN],siz[MAXN],pc[MAXN],top[MAXN];
int tot,len,id[MAXN],num[MAXN],pos[MAXN],st[20][MAXN<<1];

void dfs1(int x,int pre,int dept){
	fa[x]=pre;
	dep[x]=dept;
	siz[x]=1;
	int maxsiz=-1;
	trav(i,x){
		int ver=e[i].to;
		if(ver==pre)continue;
		dfs1(ver,x,dept+1);
		siz[x]+=siz[ver];
		if(siz[ver]>maxsiz){
			maxsiz=siz[ver];
			pc[x]=ver;
		}
	}
}

void dfs2(int x,int topf){
	top[x]=topf;
	id[x]=++tot;
	num[tot]=x;
	pos[x]=++len;
	st[0][len]=id[x];
	if(!pc[x])return;
	dfs2(pc[x],topf);
	st[0][++len]=id[x];
	trav(i,x){
		int ver=e[i].to;
		if(ver==fa[x]||ver==pc[x])continue;
		dfs2(ver,ver);
		st[0][++len]=id[x];
	}
}

void build_st(){
	int lim=log2(len);
	rin(i,1,lim)rin(j,1,len-(1<<i)+1)st[i][j]=std::min(st[i-1][j],st[i-1][j+(1<<(i-1))]);
}

inline int lca(int x,int y){
	if(!x||!y)return 0;
	x=pos[x],y=pos[y];
	if(x>y)std::swap(x,y);
	int lim=log2(y-x+1);
	return num[std::min(st[lim][x],st[lim][y-(1<<lim)+1])];
}

int s[MAXN],t[MAXN];
std::vector<int> vec[MAXN];

int sgn,root[MAXN],lc[MAXN*40],rc[MAXN*40],lb[MAXN*40],rb[MAXN*40],cov[MAXN*40],loc;
LL sum[MAXN*40];

#define mid ((l+r)>>1)

inline void pushup(int o){
	sum[o]=sum[lc[o]]+sum[rc[o]]-dep[lca(num[rb[lc[o]]],num[lb[rc[o]]])];
	lb[o]=lb[lc[o]]>0?lb[lc[o]]:lb[rc[o]];
	rb[o]=rb[rc[o]]>0?rb[rc[o]]:rb[lc[o]];
}

int insert(int pre,int l,int r){
	int o=pre;
	if(!o)o=++sgn;
	if(l==r){
		++cov[o];
		sum[o]=dep[num[l]];
		lb[o]=rb[o]=l;
		return o;
	}
	if(loc<=mid)lc[o]=insert(lc[pre],l,mid);
	else rc[o]=insert(rc[pre],mid+1,r);
	pushup(o);
	return o;
}

void erase(int o,int l,int r){
	if(l==r){
		cov[o]-=2;
		if(!cov[o])sum[o]=lb[o]=rb[o]=0;
		return;
	}
	if(loc<=mid)erase(lc[o],l,mid);
	else erase(rc[o],mid+1,r);
	pushup(o);
}

int merge(int x,int y,int l,int r){
	if(!x||!y)return x+y;
	if(l==r){
		cov[x]+=cov[y];
		sum[x]|=sum[y];
		lb[x]|=lb[y];
		rb[x]|=rb[y];
		return x;
	}
	lc[x]=merge(lc[x],lc[y],l,mid);
	rc[x]=merge(rc[x],rc[y],mid+1,r);
	pushup(x);
	return x;
}

#undef mid

LL ans;

void dfs3(int x){
	trav(i,x){
		int ver=e[i].to;
		if(ver==fa[x])continue;
		dfs3(ver);
		root[x]=merge(root[x],root[ver],1,n);
	}
	rin(i,0,Size(vec[x])-1){
		if(vec[x][i]<0){
			loc=id[s[-vec[x][i]]];
			erase(root[x],1,n);
			loc=id[t[-vec[x][i]]];
			erase(root[x],1,n);
		}
	}
	rin(i,0,Size(vec[x])-1){
		if(vec[x][i]>0){
			loc=id[s[vec[x][i]]];
			root[x]=insert(root[x],1,n);
			loc=id[t[vec[x][i]]];
			root[x]=insert(root[x],1,n);
		}
	}
	if(!lb[root[x]])return;
	int l=lca(num[lb[root[x]]],num[rb[root[x]]]);
	ans+=sum[root[x]]-dep[fa[l]]-1;
}

int main(){
	n=read(),m=read();
	rin(i,2,n){
		int u=read(),v=read();
		add_edge(u,v);
		add_edge(v,u);
	}
	dfs1(1,0,1);
	dfs2(1,1);
	build_st();
	rin(i,1,m){
		s[i]=read(),t[i]=read();
		int l=lca(s[i],t[i]);
		vec[s[i]].pb(i);
		vec[t[i]].pb(i);
		vec[fa[l]].pb(-i);
	}
	dfs3(1);
	printf("%lld\n",ans>>1);
	return 0;
}

posted on 2019-05-08 08:16  ErkkiErkko  阅读(330)  评论(0编辑  收藏  举报