Loading

题解 CF414E 【Mashmokh's Designed Problem】

题意

给定一棵 \(n\) 个节点的有根树,每个点连出的边都有序,共有 \(m\) 个操作。操作分三类:

  1. 查询两个点 \(u,v\) 的距离
  2. \(v\) 为根的子树从树中分开,并添加一条与其第 \(h\) 个祖先的连边作为该祖先的最后一个儿子。
  3. 查询从一个点出发,按边的顺序进行dfs,深度为 \(k\) 的最后遍历的点

\((n \le 10^5,m \le 10^5)\)

题解

如果只有 \(1\)\(2\) 似乎可以 \(\mathbb{LCT}\) 做,但边有顺序无疑让其无法实现。

树的遍历想到欧拉序 (别问我怎么想到的),通过欧拉序,边就都有序了,于是可以完成以上的操作。

记节点 \(i\) 在欧拉序中第一遍为 \(in[i]\) ,第二遍为 \(out[i]\)

欧拉序:\(1\ 2\ 3\ 4\ 4\ 3\ 2\ 1\)

操作一

如果我们记节点 \(i\) 的深度为 \(dep_i\),那么 \(u\to v\) 的距离就是 \(dep_u+dep_v-2\times dep_{lca(u,v)}\)

假如我们已经能够维护 \(u,v\) 的深度那么考虑 \(\rm lca\) 的深度与欧拉序有什么关系。

如果我们查\(5\to3\)\(\rm lca\)

欧拉序:\(\color{black}1\ 2\ 5\ \color{red}5\ 2\ 4\ 4\ 3\ 6\ \color{black}3\ 1\)

发现在 \((in[u],out[v])\) 之间的所有点全在他们的 \(\rm lca\) 以下,取这一段深度的最小值加一即可。

不过注意此时 \(u\)\(v\) 先访问,如果不满足就强行swap(u,v)

取最小值操作我们可以用 \(\rm Splay\)实现。(别问我为什么不用线段树,因为还有其他操作(

操作二

用欧拉序我们可以轻松摘下一棵树再挂上去:

欧拉序:\(\color{black}1\ 2\ \color{red}4\ 6\ 6\ 5\ 5\ 4\ \color{black}2\ 3\ 3\ 1\)

\(4\) 的子树摘下来挂到 \(1\)

欧拉序: \(\color{black}1\ 2\ 2\ 3\ 3\ \color{blue}4\ 6\ 6\ 5\ 5\ 4\ \color{black}1\)

不难发现把 \(v\) 挂到 \(u\) 上就是把 \([in[v],out[v]]\) 这段区间插到 \(out[u]\) 之前,并把其深度都减少 \(h-1\)

现在大量的序列操作, \(\rm Splay\) 序列之王的用处就体现出来了。

于是现在的问题变为怎么找到 \(u\) ,其实就是 \([1,in[v])\)中最后一个深度为 \(dep[v]-h\) 的点,于是自然引入了操作3

操作三

对于一个深度为 \(k\) 的节点,如果他排列在最后,那么他在欧拉序中的出现也一定在最后。

那么我们对每个节点维护最大值 \(mx\) 与最小值 \(mn\) ,并优先访问右儿子(如果\(k\in [mn_l,mx_l]\)),再是当前节点、左儿子。

那么现在用 \(\mathcal{O}(\log n)\) 的时间,就可以找到最后一个深度为 \(k\) 的了。那么在 操作二寻找指定深度的节点也可以迎刃而解了。

\(\rm Splay\)维护序列这里就不细讲了,做几个模板就都会了。

代码

#include<bits/stdc++.h>
//#define faster
namespace fin{
	#ifdef faster
	char buf[1<<21],*p1=buf,*p2=buf;
	inline int getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
	#else
	inline int getc(){return getchar();}
	#endif
	template <typename T>inline void read(T& t){
		t=0;int f=0;char ch=getc();while (!isdigit(ch)){if(ch=='-')f = 1;ch=getc();}
	    while(isdigit(ch)){t=t*10+ch-48;ch = getc();}if(f)t=-t;
	}
	template <typename T,typename... Args> inline void read(T& t, Args&... args){read(t);read(args...);}
}
namespace fout{
	char buffer[1<<21];int p1=-1;const int p2 = (1<<21)-1;
	inline void flush(){fwrite(buffer,1,p1+1,stdout),p1=-1;}
	inline void putc(const char &x) {if(p1==p2)flush();buffer[++p1]=x;}
	template <typename T>void write(T x) {
		static char buf[15];static int len=-1;if(x>=0){do{buf[++len]=x%10+48,x/=10;}while (x);}else{putc('-');do {buf[++len]=-(x%10)+48,x/=10;}while(x);}
		while (len>=0)putc(buf[len]),--len;
	}
}
using namespace std;
const int N=2e5+10;
int n,m,in[N],out[N],id[N],dep[N],cnt;
vector<int>e[N];int root;
void dfs(int u,int dep){
	in[u]=++cnt,id[cnt]=u,::dep[cnt]=dep;
	for(auto v:e[u])dfs(v,dep+1);
	out[u]=++cnt,id[cnt]=u,::dep[cnt]=dep;
}
int f[N],ch[N][2],mn[N],mx[N],tag[N],sz[N];
#define min3(a,b,c) min((a),min(b,c))
#define max3(a,b,c) max((a),max(b,c))
void pushup(int x){
	int l=ch[x][0],r=ch[x][1];
	sz[x]=sz[l]+1+sz[r];
	mx[x]=max3(mx[l],dep[x],mx[r]);
	mn[x]=min3(mn[l],dep[x],mn[r]);
}
void add(int x,int val){
	mn[x]+=val,mx[x]+=val,tag[x]+=val,dep[x]+=val;
}
void pushdown(int x){
	int l=ch[x][0],r=ch[x][1];
	if(tag[x]){
		if(l)add(l,tag[x]);
		if(r)add(r,tag[x]);
	}tag[x]=0;
}
#define get(x) (ch[f[(x)]][1]==(x))
#define connect(x,y,p) {if(x)f[(x)]=(y);if(y)ch[(y)][(p)]=(x);}
void rotate(int x){
	int fa=f[x],ffa=f[fa],p1=get(x),p2=get(fa);
	connect(ch[x][p1^1],fa,p1);
	connect(fa,x,p1^1);
	connect(x,ffa,p2);
	pushup(fa);pushup(x);
}
void splay(int x,int goal=0){
	stack<int>q;q.push(x);
	for(int fa=x;f[fa]!=goal;fa=f[fa])q.push(f[fa]);
	while(q.size())pushdown(q.top()),q.pop();
	for(int fa;(fa=f[x])!=goal;rotate(x))
		if(f[fa]!=goal)rotate(get(x)==get(fa)?fa:x);
	if(!goal)root=x;
}
#define mid (l+r>>1)
int build(int l,int r){
	if(l<mid)ch[mid][0]=build(l,mid-1),f[ch[mid][0]]=mid;
	if(mid<r)ch[mid][1]=build(mid+1,r),f[ch[mid][1]]=mid;
	pushup(mid);return mid;
}
int findk(int x,int k){
	int l=ch[x][0],r=ch[x][1];pushdown(x);
	if(mn[r]<=k&&k<=mx[r])return findk(r,k);
	else if(dep[x]==k){splay(x);return id[x];}
	else return findk(l,k);
}
int split(int l,int r){
	splay(l);splay(r,l);return ch[r][0];
}
int getpre(int x){
	splay(x);x=ch[x][0];
	while(ch[x][1])x=ch[x][1];
	splay(x);return x;
}
int getsuf(int x){
	splay(x);x=ch[x][1];
	while(ch[x][0])x=ch[x][0];
	splay(x);return x;
}
int del(int l,int r){
	int L=getpre(l),R=getsuf(r);
	int x=split(L,R),y=f[x];
	ch[y][0]=f[x]=0;pushup(y);pushup(f[y]);
	return x;
}
signed main(){
	fin::read(n,m);
	for(int i=1,u,v;i<=n;i++){
		fin::read(u);
		for(int j=1;j<=u;j++)fin::read(v),e[i].push_back(v);
	}
	mn[0]=0x3f3f3f3f;
	dfs(1,0);build(1,cnt);
	while(m--){
		int op,u,v,k;
		fin::read(op);
		if(op==1){
			fin::read(u,v);
			splay(in[u]);int rku=sz[ch[in[u]][0]]+1,res=dep[in[u]];
			splay(in[v]);int rkv=sz[ch[in[v]][0]]+1;res+=dep[in[v]];
			if(rku>rkv)swap(u,v);u=split(in[u],out[v]);res-=(mn[u]-1)*2;
			fout::write(res);fout::putc('\n');
		}else if(op==2){
			fin::read(u,v);int val=v-1;
			splay(in[u]);v=findk(ch[in[u]][0],dep[in[u]]-v);
			u=del(in[u],out[u]);add(u,-val);
			int rt=getpre(out[v]);splay(rt);
			splay(out[v],rt);
			connect(u,out[v],0);
			pushup(out[v]),pushup(rt);
		}else{
			fin::read(k);
			splay(1);
			fout::write(findk(1,k));fout::putc('\n');
		}
	}
	fout::flush();
	return 0;
}
posted @ 2021-01-23 22:13  caigou233  阅读(115)  评论(1编辑  收藏  举报