平衡树板子
以P3369 普通平衡树为例。
有旋Treap
#include <bits/stdc++.h>
using namespace std;
const int maxn=100010,inf=2147483645;
int q,opt,tot,root,x;
struct treap{
int l,r,dat,val,siz,cnt;
}tr[maxn*5];
inline int read(){
int s=0,w=1;char ch=getchar();
while(ch<'0'||ch>'9'){
if(ch=='-') w=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9'){
s=(s<<1)+(s<<3)+(ch^48);
ch=getchar();
}
return s*w;
}
inline int New(int val){
tr[++tot].val=val;
tr[tot].dat=rand();
tr[tot].cnt=tr[tot].siz=1;
//新建点
return tot;
}
inline void update(int p){
tr[p].siz=tr[tr[p].l].siz+tr[tr[p].r].siz+tr[p].cnt;
}
inline void build(){
root=New(-inf);tr[root].r=New(inf);
update(root);
//建树 更新信息
}
int getrank(int p,int val){
if(!p) return 0;
if(val==tr[p].val) return tr[tr[p].l].siz+1;
//查询到该值,返回左子树的大小+1(即排名
if(val<tr[p].val) return getrank(tr[p].l,val);
return getrank(tr[p].r,val)+tr[tr[p].l].siz+tr[p].cnt;
}
int getval(int p,int rank){
if(!p) return inf;
if(rank<=tr[tr[p].l].siz) return getval(tr[p].l,rank);
//必须写在前面
if(rank<=tr[tr[p].l].siz+tr[p].cnt) return tr[p].val;
//若排名在这个区间范围内可以直接返回值
return getval(tr[p].r,rank-tr[tr[p].l].siz-tr[p].cnt);
}
inline void zig(int &p){
//注意引用 右旋
int q=tr[p].l;
tr[p].l=tr[q].r,tr[q].r=p,p=q;
update(tr[p].r),update(p);
}
inline void zag(int &p){
//注意引用 左旋
int q=tr[p].r;
tr[p].r=tr[q].l,tr[q].l=p,p=q;
update(tr[p].l),update(p);
}
inline void insert(int &p,int val){
//注意引用 插入
if(!p){
p=New(val);
return;
}
if(val==tr[p].val){
tr[p].cnt++;
update(p);
return;
}
if(val<tr[p].val){
insert(tr[p].l,val);
if(tr[p].dat<tr[tr[p].l].dat) zig(p);
//在这里右旋 不满足堆性质
}else{
insert(tr[p].r,val);
if(tr[p].dat<tr[tr[p].r].dat) zag(p);
//在这里左旋
}
update(p);
//注意更新
}
inline int getpre(int val){
int ans=1,p=root;
while(p){
if(val==tr[p].val){
if(tr[p].l){
p=tr[p].l;
while(tr[p].r) p=tr[p].r;
ans=p;
}
break;
}
if(val>tr[p].val&&tr[p].val>tr[ans].val) ans=p;
p=val<tr[p].val?tr[p].l:tr[p].r;
}
return tr[ans].val;
}
inline int getnxt(int val){
int ans=2,p=root;
while(p){
if(val==tr[p].val){
if(tr[p].r){
p=tr[p].r;
while(tr[p].l) p=tr[p].l;
ans=p;
}
break;
}
if(val<tr[p].val&&tr[p].val<tr[ans].val) ans=p;
p=val<tr[p].val?tr[p].l:tr[p].r;
}
return tr[ans].val;
}
void remove(int &p,int val){
//注意引用 删除
if(!p) return;
if(val==tr[p].val){
if(tr[p].cnt>1){
tr[p].cnt--;
update(p);
return;
}
if(tr[p].l||tr[p].r){
if(!tr[p].r||tr[tr[p].l].dat>tr[tr[p].r].dat) zig(p),remove(tr[p].r,val);
else zag(p),remove(tr[p].l,val);
update(p);
//注意更新
}else p=0;
return;
}
val<tr[p].val?remove(tr[p].l,val):remove(tr[p].r,val);
update(p);
}
int main(){
build();q=read();
while(q--){
opt=read();x=read();
if(opt==1) insert(root,x);
else if(opt==2) remove(root,x);
else if(opt==3) printf("%d\n",getrank(root,x)-1);
else if(opt==4) printf("%d\n",getval(root,x+1));
else if(opt==5) printf("%d\n",getpre(x));
else printf("%d\n",getnxt(x));
}
return 0;
}
无旋Treap
#include <bits/stdc++.h>
#define lson(x) tr[(x)].l
#define rson(x) tr[(x)].r
using namespace std;
const int maxn=2e5+20;
int q,opt,tot,root;
struct N_treap{
int l,r,val,cnt,dat,siz;
}tr[maxn*5];
inline void pushup(int p){
tr[p].siz=tr[p].cnt+tr[lson(p)].siz+tr[rson(p)].siz;
}
inline int New(int val){
tr[++tot].val=val;
tr[tot].dat=rand();
lson(tot)=rson(tot)=0;
tr[tot].cnt=tr[tot].siz=1;
return tot;
}
inline int merge(int l,int r){
if(!l||!r) return l+r;
if(tr[l].dat>tr[r].dat){
rson(l)=merge(rson(l),r);
pushup(l);
return l;
}else{
lson(r)=merge(l,lson(r));
pushup(r);
return r;
}
}
inline void spilt(int p,int val,int &l,int &r){
if(!p) l=r=0;
else{
if(tr[p].val<=val){
l=p;
spilt(rson(l),val,rson(l),r);
pushup(l);
}else{
r=p;
spilt(lson(r),val,l,lson(r));
pushup(r);
}
}
}
inline void insert(int val){
int l,r;spilt(root,val,l,r);
root=merge(merge(l,New(val)),r);
}
inline void remove(int val){
int l,r,s;spilt(root,val-1,l,r);
spilt(r,val,r,s);
r=merge(lson(r),rson(r));
root=merge(merge(l,r),s);
return;
}
inline int getrank(int val){
int l,r;
spilt(root,val-1,l,r);
int res=tr[l].siz+1;
root=merge(l,r);
return res;
}
inline int kth(int p,int k){
if(k<=tr[lson(p)].siz) return kth(lson(p),k);
k-=tr[lson(p)].siz+tr[p].cnt;
if(k<=0) return p;
else return kth(rson(p),k);
}
inline int Kth(int k){
return tr[kth(root,k)].val;
}
inline int pre(int val){
int l,r;
spilt(root,val-1,l,r);
int res=tr[kth(l,tr[l].siz)].val;
root=merge(l,r);
return res;
}
inline int suc(int val){
int l,r;
spilt(root,val,l,r);
int res=tr[kth(r,1)].val;
root=merge(l,r);
return res;
}
inline int read(){
int s=0,w=1;char ch=getchar();
while(ch<'0'||ch>'9'){
if(ch=='-') w=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9'){
s=(s<<1)+(s<<3)+(ch^48);
ch=getchar();
}
return s*w;
}
int main(){
q=read();
while(q--){
int dos;
opt=read();dos=read();
if(opt==1) insert(dos);
else if(opt==2) remove(dos);
else if(opt==3) printf("%d\n",getrank(dos));
else if(opt==4) printf("%d\n",Kth(dos));
else if(opt==5) printf("%d\n",pre(dos));
else printf("%d\n",suc(dos));
}
return 0;
}
Splay
#include <bits/stdc++.h>
#define lson(x) tr[(x)].son[0]
#define rson(x) tr[(x)].son[1]
#define fu(x) tr[(x)].fa
using namespace std;
namespace Broken_Eclipse{
inline int read(){
int s=0,w=1;char ch;
while(!isdigit(ch=getchar())) if(ch=='-') w=-1;
do s=s*10+(ch^48);while(isdigit(ch=getchar()));
return s*w;
}
}
using namespace Broken_Eclipse;
const int maxn=100010;
int root,tot,q,opt,x;
struct splay{
int son[2],val,cnt,siz,fa;
}tr[maxn*5];
inline void update(int p){
tr[p].siz=tr[lson(p)].siz+tr[rson(p)].siz+tr[p].cnt;
}
inline bool get(int p){
return p==rson(fu(p));
}
inline void clear(int p){
lson(p)=rson(p)=fu(p)=tr[p].val=tr[p].siz=tr[p].cnt=0;
}
inline void poer(int p){
int y=fu(p),z=fu(y),chk=get(p);
tr[y].son[chk]=tr[p].son[chk^1];
if(tr[p].son[chk^1]) fu(tr[p].son[chk^1])=y;
tr[p].son[chk^1]=y;
fu(y)=p;
fu(p)=z;
if(z) tr[z].son[y==rson(z)]=p;
update(p);
update(y);
return;
}
inline void splay(int p){
for(int f=fu(p);f=fu(p),f;poer(p))
if(fu(f)) poer(get(p)==get(f)?f:p);
root=p;
}
inline void insert(int val){
if(!root){
tr[++tot].val=val;
tr[tot].cnt++;
root=tot;
update(root);
return;
}
int cur=root,f=0;
while(1){
if(tr[cur].val==val){
tr[cur].cnt++;
update(cur);
update(f);
splay(cur);
break;
}
f=cur;
cur=tr[cur].son[tr[cur].val<val];
if(!cur){
tr[++tot].val=val;
tr[tot].cnt++;
fu(tot)=f;
tr[f].son[tr[f].val<val]=tot;
update(tot);
update(f);
splay(tot);
break;
}
}
}
inline int getrank(int val){
int res=0,cur=root;
while(1){
if(val<tr[cur].val) cur=lson(cur);
else{
res+=tr[lson(cur)].siz;
if(val==tr[cur].val){
splay(cur);
return res+1;
}
res+=tr[cur].cnt;
cur=rson(cur);
}
}
}
inline int getval(int rank){
int cur=root;
while(1){
if(lson(cur)&&rank<=tr[lson(cur)].siz) cur=lson(cur);
else{
rank-=tr[cur].cnt+tr[lson(cur)].siz;
if(rank<=0){
splay(cur);
return tr[cur].val;
}
cur=rson(cur);
}
}
}
inline int getpre(){
int cur=lson(root);
if(!cur) return cur;
while(rson(cur)) cur=rson(cur);
splay(cur);
return cur;
}
inline int getnxt(){
int cur=rson(root);
if(!cur) return cur;
while(lson(cur)) cur=lson(cur);
splay(cur);
return cur;
}
inline void del(int val){
getrank(val);
if(tr[root].cnt>1){
tr[root].cnt--;
update(root);
return;
}
if(!lson(root)&&!rson(root)){
clear(root);
root=0;
return;
}
if(!lson(root)){
int cur=root;
root=rson(root);
fu(root)=0;
clear(cur);
return;
}
if(!rson(root)){
int cur=root;
root=lson(root);
fu(root)=0;
clear(cur);
return;
}
int cur=root;
int ze=getpre();
fu(rson(cur))=ze;
rson(ze)=rson(cur);
clear(cur);
update(root);
return;
}
int main(){
q=read();
while(q--){
opt=read();x=read();
if(opt==1) insert(x);
else if(opt==2) del(x);
else if(opt==3) printf("%d\n",getrank(x));
else if(opt==4) printf("%d\n",getval(x));
else if(opt==5) insert(x),printf("%d\n",tr[getpre()].val),del(x);
else insert(x),printf("%d\n",tr[getnxt()].val),del(x);
}
return 0;
}