洛谷 P3224 [HNOI2012]永无乡(Splay合并)
传送门
解题思路
若干平衡树,每次操作有两种,一是合并两个Splay,二是查询某一个点所在的平衡树里的第k小的点的编号。
首先用并查集维护某个点在哪个平衡树里,然后rt[i]记录编号为i的平衡树的根。
每次合并时启发式合并,直接把小的树的每个点暴力insert到大树里。
查询正常操作即可。
为了方便可以把原来普通平衡树的0节点,改为第i课树的0节点是i,然后其他点的编号都加n,方便操作。
AC代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#include<queue>
#include<map>
#include<bitset>
#include<stack>
using namespace std;
const int maxn=2e5+5;
int fa[maxn],siz[maxn],n,m,rt[maxn],q;
struct node{
int fa,son[2],val,siz;
}tr[maxn];
inline int find(int x){
if(fa[x]==x) return x;
return fa[x]=find(fa[x]);
}
void init(int i,int fa){
tr[i].fa=fa;
tr[i].son[0]=tr[i].son[1]=0;
tr[i].siz=1;
}
void update(int x){
tr[x].siz=1;
if(tr[x].son[0]) tr[x].siz+=tr[tr[x].son[0]].siz;
if(tr[x].son[1]) tr[x].siz+=tr[tr[x].son[1]].siz;
}
void rotate(int x){
int y=tr[x].fa,z=tr[y].fa;
int c=(tr[y].son[1]==x);
tr[tr[x].son[!c]].fa=y;
tr[x].fa=tr[y].fa;
tr[y].fa=x;
if(z) tr[z].son[tr[z].son[1]==y]=x;
tr[y].son[c]=tr[x].son[!c];
tr[x].son[!c]=y;
update(y);
update(x);
}
void splay(int x,int goal){
if(x==goal) return;
while(tr[x].fa!=goal){
int y=tr[x].fa,z=tr[y].fa;
if(z!=goal) ((tr[y].son[0]==x)^(tr[z].son[0]==y))?rotate(x):rotate(y);
rotate(x);
}
if(goal<=n) rt[goal]=x;
}
void insert(int y,int id){
int x=rt[y];
while(1){
if(tr[x].son[tr[x].val<tr[id].val]) x=tr[x].son[tr[x].val<tr[id].val];
else{
init(id,x);
tr[x].son[tr[x].val<tr[id].val]=id;
splay(id,y);
return;
}
}
}
void del(int x,int y){
if(tr[x].son[0]) del(tr[x].son[0],y);
if(tr[x].son[1]) del(tr[x].son[1],y);
insert(y,x);
}
void merge(int x,int y){
int fx=find(x),fy=find(y);
if(fx==fy) return;
if(siz[fx]>siz[fy]) swap(fx,fy);
fa[fx]=fy;
siz[fy]+=siz[fx];
del(rt[fx],fy);
}
int getval(int x,int k){
if(tr[x].siz<k) return -1;
while(1){
if((tr[x].son[0]?tr[tr[x].son[0]].siz+1:1)==k){
return x-n;
}
if((tr[x].son[0]?tr[tr[x].son[0]].siz+1:1)<k){
k-=(tr[x].son[0]?tr[tr[x].son[0]].siz+1:1);
x=tr[x].son[1];
}else{
x=tr[x].son[0];
}
}
}
int main(){
ios::sync_with_stdio(false);
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>tr[i+n].val,tr[i+n].fa=i,fa[i]=fa[i+n]=i,siz[i]=1,rt[i]=i+n,tr[i+n].siz=1;
for(int i=1;i<=m;i++){
int u,v;
cin>>u>>v;
merge(u,v);
}
cin>>q;
for(int i=1;i<=q;i++){
char c;
int x,y;
cin>>c>>x>>y;
if(c=='Q'){
cout<<getval(rt[find(x)],y)<<endl;
}else{
merge(find(x),find(y));
}
}
return 0;
}