C03【模板】Splay P3369 普通平衡树
视频链接:261【模板】Splay P3369 普通平衡树_哔哩哔哩_bilibili
#include <iostream> using namespace std; #define ls(x) tr[x].ch[0] #define rs(x) tr[x].ch[1] const int N=1100010, INF=(1<<30)+1; struct node{ int ch[2]; //儿 int fa; //父 int v; //点权 int cnt; //点权次数 int siz; //子树大小 void init(int p,int v1){ fa=p, v=v1; cnt=siz=1; } }tr[N]; int root,tot; //根,节点个数 void pushup(int x){ //上传 tr[x].siz=tr[ls(x)].siz+tr[rs(x)].siz+tr[x].cnt; } void rotate(int x){ //旋转 int y=tr[x].fa, z=tr[y].fa, k=tr[y].ch[1]==x; //y的右儿是x tr[z].ch[tr[z].ch[1]==y]=x, tr[x].fa=z; //z的儿是x,x的父是z tr[y].ch[k]=tr[x].ch[k^1], tr[tr[x].ch[k^1]].fa=y; //y的儿是x的异儿,x的异儿的父是y tr[x].ch[k^1]=y, tr[y].fa=x; //x的异儿是y,y的父是x pushup(y), pushup(x); //自底向上push } void splay(int x, int k){ //伸展 while(tr[x].fa!=k){ //折线转xx,直线转yx int y=tr[x].fa, z=tr[y].fa; if(z!=k) (ls(y)==x)^(ls(z)==y)?rotate(x):rotate(y); rotate(x); } if(!k) root=x; //k=0时,x转到根 } void insert(int v){ //插入 int x=root, p=0; //x走到空节点或走到目标点结束 while(x&&tr[x].v!=v) p=x,x=tr[x].ch[v>tr[x].v]; if(x) tr[x].cnt++; //目标点情况 else{ //空节点情况 x=++tot; tr[p].ch[v>tr[p].v]=x; tr[x].init(p,v); } splay(x, 0); } void find(int v){ //找到v并转到根 int x=root; while(tr[x].ch[v>tr[x].v]&&v!=tr[x].v) x=tr[x].ch[v>tr[x].v]; splay(x, 0); } int getpre(int v){ //前驱 find(v); int x=root; if(tr[x].v<v) return x; x=ls(x); while(rs(x)) x=rs(x); splay(x, 0); return x; } int getsuc(int v){ //后继 find(v); int x=root; if(tr[x].v>v) return x; x=rs(x); while(ls(x)) x=ls(x); splay(x, 0); return x; } void del(int v){ //删除 int pre=getpre(v); int suc=getsuc(v); splay(pre,0), splay(suc,pre); int del=tr[suc].ch[0]; if(tr[del].cnt>1) tr[del].cnt--, splay(del,0); else tr[suc].ch[0]=0, splay(suc,0); } int getrank(int v){ //排名 insert(v); int res=tr[tr[root].ch[0]].siz; del(v); return res; } int getval(int k){ //数值 int x=root; while(true){ if(k<=tr[ls(x)].siz) x=ls(x); else if(k<=tr[ls(x)].siz+tr[x].cnt) break; else k-=tr[ls(x)].siz+tr[x].cnt, x=rs(x); } splay(x, 0); return tr[x].v; } int main(){ insert(-INF);insert(INF); //哨兵 int n,op,x; scanf("%d", &n); while(n--){ scanf("%d%d", &op, &x); if(op==1) insert(x); else if(op==2) del(x); else if(op==3) printf("%d\n",getrank(x)); else if(op==4) printf("%d\n",getval(x+1)); else if(op==5) printf("%d\n",tr[getpre(x)].v); else printf("%d\n",tr[getsuc(x)].v); } }
#include <iostream> using namespace std; const int N=1100010, INF=(1<<30)+1; int ch[N][2],fa[N],val[N],cnt[N],siz[N],root,tot; //ch儿,fa父,val点权,cnt点权次数,siz子树大小,root根,tot节点个数 void pushup(int x){ //上传 siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+cnt[x]; } void rotate(int x){ //旋转 int y=fa[x],z=fa[y],k=ch[y][1]==x; //y的右儿是x ch[z][ch[z][1]==y]=x, fa[x]=z; //z的儿是x,x的父是z ch[y][k]=ch[x][k^1], fa[ch[x][k^1]]=y; //y的儿是x的异儿,x的异儿的父是y ch[x][k^1]=y, fa[y]=x; //x的异儿是y,y的父是x pushup(y); pushup(x); //自底向上push } void splay(int x, int k){ //伸展 while(fa[x]!=k){ //折线转xx,直线转yx int y=fa[x], z=fa[y]; if(z!=k) (ch[y][0]==x)^(ch[z][0]==y)?rotate(x):rotate(y); rotate(x); } if(!k) root=x; //k=0时,x转到根 } void insert(int v){ //插入 int x=root, p=0; //x走到空节点或走到目标点结束 while(x&&val[x]!=v) p=x,x=ch[x][v>val[x]]; if(x) cnt[x]++; //目标点情况 else{ //空节点情况 x=++tot; ch[p][v>val[p]]=x;fa[x]=p; val[x]=v;cnt[x]=siz[x]=1; } splay(x,0); } void find(int v){ //找到v并转到根 int x=root; while(ch[x][v>val[x]]&&v!=val[x]) x=ch[x][v>val[x]]; splay(x,0); } int getpre(int v){ //前驱 find(v); int x=root; if(val[x]<v) return x; x=ch[x][0]; while(ch[x][1]) x=ch[x][1]; splay(x,0); return x; } int getsuc(int v){ //后继 find(v); int x=root; if(val[x]>v) return x; x=ch[x][1]; while(ch[x][0]) x=ch[x][0]; splay(x,0); return x; } void del(int v){ //删除 int pre=getpre(v); int suc=getsuc(v); splay(pre,0), splay(suc,pre); int del=ch[suc][0]; if(cnt[del]>1) cnt[del]--, splay(del,0); else ch[suc][0]=0, splay(suc,0); } int getrank(int v){ //排名 insert(v); int res=siz[ch[root][0]]; del(v); return res; } int getval(int k){ //数值 int x=root; while(true){ if(k<=siz[ch[x][0]]) x=ch[x][0]; else if(k<=siz[ch[x][0]]+cnt[x]) break; else k-=siz[ch[x][0]]+cnt[x], x=ch[x][1]; } splay(x,0); return val[x]; } int main(){ insert(-INF);insert(INF); //哨兵 int n,op,x; scanf("%d", &n); while(n--){ scanf("%d%d", &op, &x); if(op==1) insert(x); else if(op==2) del(x); else if(op==3) printf("%d\n",getrank(x)); else if(op==4) printf("%d\n",getval(x+1)); else if(op==5) printf("%d\n",val[getpre(x)]); else printf("%d\n",val[getsuc(x)]); } }
#include <iostream> using namespace std; #define ls(x) tr[x].s[0] #define rs(x) tr[x].s[1] const int N=1100010, INF=(1<<30)+1; struct node{ int s[2]; //左右儿子 int p; //父亲 int v; //节点权值 int cnt; //权值出现次数 int siz; //子树大小 void init(int p1,int v1){ p=p1, v=v1; cnt=siz=1; } }tr[N]; int root; //根节点编号 int idx; //节点个数 void pushup(int x){ tr[x].siz=tr[ls(x)].siz+tr[rs(x)].siz+tr[x].cnt; } void rotate(int x){ int y=tr[x].p, z=tr[y].p; int k = tr[y].s[1]==x; tr[z].s[tr[z].s[1]==y] =x; tr[x].p = z; tr[y].s[k] = tr[x].s[k^1]; tr[tr[x].s[k^1]].p = y; tr[x].s[k^1] = y; tr[y].p = x; pushup(y), pushup(x); } void splay(int x, int k){ while(tr[x].p!=k){ int y=tr[x].p, z=tr[y].p; if(z!=k) // 折转底,直转中 (ls(y)==x)^(ls(z)==y) ? rotate(x) : rotate(y); rotate(x); } if(!k) root=x; } void insert(int v){ //插入 int x=root, p=0; while(x && tr[x].v!=v) p=x, x=tr[x].s[v>tr[x].v]; if(x) tr[x].cnt++; else{ x=++idx; if(p) tr[p].s[v>tr[p].v]=x; tr[x].init(p,v); } splay(x, 0); } void find(int v){ //找到v并转到根 int x=root; while(tr[x].s[v>tr[x].v]&&v!=tr[x].v) x=tr[x].s[v>tr[x].v]; splay(x, 0); } int getpre(int v){ //前驱 find(v); int x=root; if(tr[x].v<v) return x; x=ls(x); while(rs(x)) x=rs(x); splay(x, 0); return x; } int getsuc(int v){ //后继 find(v); int x=root; if(tr[x].v>v) return x; x=rs(x); while(ls(x)) x=ls(x); splay(x, 0); return x; } void del(int v){ //删除 int pre=getpre(v); int suc=getsuc(v); splay(pre,0), splay(suc,pre); int del=tr[suc].s[0]; if(tr[del].cnt>1) tr[del].cnt--, splay(del,0); else tr[suc].s[0]=0, splay(suc,0); } int getrank(int v){ //排名 insert(v); int res=tr[tr[root].s[0]].siz; del(v); return res; } int getval(int k){ //数值 int x=root; while(true){ if(k<=tr[ls(x)].siz) x=ls(x); else if(k<=tr[ls(x)].siz+tr[x].cnt) break; else k-=tr[ls(x)].siz+tr[x].cnt, x=rs(x); } splay(x, 0); return tr[x].v; } int main(){ insert(-INF);insert(INF); //哨兵 int n,t; scanf("%d%d", &n,&t); for(int i=1; i<=n; i++){ int x; scanf("%d", &x); insert(x); } int res=0, last=0; while(t--){ int op,x; scanf("%d%d", &op, &x); x^=last; if(op==1) insert(x); if(op==2) del(x); if(op==3) res^=(last=getrank(x)); if(op==4) res^=(last=getval(x+1)); if(op==5) res^=(last=tr[getpre(x)].v); if(op==6) res^=(last=tr[getsuc(x)].v); } printf("%d\n",res); return 0; }