C03【模板】Splay P3369 普通平衡树

视频链接:261【模板】Splay P3369 普通平衡树_哔哩哔哩_bilibili

 

 

 

 

 

 

 

Luogu P3369【模板】普通平衡树

#include <iostream>
using namespace std;

#define ls(x) tr[x].ch[0]
#define rs(x) tr[x].ch[1]
const int N=1100010, INF=(1<<30)+1;
struct node{
  int ch[2]; //
  int fa; //
  int v;  //点权
  int cnt; //点权次数
  int siz; //子树大小
  void init(int p,int v1){
    fa=p, v=v1;
    cnt=siz=1;
  }
}tr[N];
int root,tot; //根,节点个数

void pushup(int x){ //上传
  tr[x].siz=tr[ls(x)].siz+tr[rs(x)].siz+tr[x].cnt;
}
void rotate(int x){ //旋转
  int y=tr[x].fa, z=tr[y].fa, k=tr[y].ch[1]==x; //y的右儿是x
  tr[z].ch[tr[z].ch[1]==y]=x, tr[x].fa=z; //z的儿是x,x的父是z
  tr[y].ch[k]=tr[x].ch[k^1], tr[tr[x].ch[k^1]].fa=y; //y的儿是x的异儿,x的异儿的父是y
  tr[x].ch[k^1]=y, tr[y].fa=x; //x的异儿是y,y的父是x
  pushup(y), pushup(x); //自底向上push
}
void splay(int x, int k){ //伸展
  while(tr[x].fa!=k){ //折线转xx,直线转yx
    int y=tr[x].fa, z=tr[y].fa;
    if(z!=k) (ls(y)==x)^(ls(z)==y)?rotate(x):rotate(y);
    rotate(x);
  }
  if(!k) root=x; //k=0时,x转到根
}
void insert(int v){ //插入
  int x=root, p=0;
  //x走到空节点或走到目标点结束
  while(x&&tr[x].v!=v) p=x,x=tr[x].ch[v>tr[x].v];
  if(x) tr[x].cnt++; //目标点情况
  else{ //空节点情况
    x=++tot;
    tr[p].ch[v>tr[p].v]=x;
    tr[x].init(p,v);
  }
  splay(x, 0);
}
void find(int v){ //找到v并转到根
  int x=root;
  while(tr[x].ch[v>tr[x].v]&&v!=tr[x].v) 
    x=tr[x].ch[v>tr[x].v]; 
  splay(x, 0);
}
int getpre(int v){ //前驱
  find(v);
  int x=root;
  if(tr[x].v<v) return x;
  x=ls(x);
  while(rs(x)) x=rs(x);
  splay(x, 0);
  return x;
}
int getsuc(int v){ //后继
  find(v);
  int x=root;
  if(tr[x].v>v) return x;
  x=rs(x);
  while(ls(x)) x=ls(x);
  splay(x, 0);
  return x;
}
void del(int v){ //删除
  int pre=getpre(v);
  int suc=getsuc(v);
  splay(pre,0), splay(suc,pre);
  int del=tr[suc].ch[0];
  if(tr[del].cnt>1)
    tr[del].cnt--, splay(del,0);
  else
    tr[suc].ch[0]=0, splay(suc,0);
}
int getrank(int v){ //排名
  insert(v);
  int res=tr[tr[root].ch[0]].siz;
  del(v);
  return res;
}
int getval(int k){ //数值
  int x=root;
  while(true){
    if(k<=tr[ls(x)].siz) x=ls(x);
    else if(k<=tr[ls(x)].siz+tr[x].cnt) break;
    else k-=tr[ls(x)].siz+tr[x].cnt, x=rs(x);
  }
  splay(x, 0);
  return tr[x].v;
}
int main(){
  insert(-INF);insert(INF); //哨兵
  int n,op,x; scanf("%d", &n);
  while(n--){
    scanf("%d%d", &op, &x);
    if(op==1) insert(x);
    else if(op==2) del(x);
    else if(op==3) printf("%d\n",getrank(x));
    else if(op==4) printf("%d\n",getval(x+1));
    else if(op==5) printf("%d\n",tr[getpre(x)].v);
    else printf("%d\n",tr[getsuc(x)].v);
  }
}

 

#include <iostream>
using namespace std;

const int N=1100010, INF=(1<<30)+1;
int ch[N][2],fa[N],val[N],cnt[N],siz[N],root,tot;
//ch儿,fa父,val点权,cnt点权次数,siz子树大小,root根,tot节点个数

void pushup(int x){ //上传
  siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+cnt[x];
}
void rotate(int x){ //旋转
  int y=fa[x],z=fa[y],k=ch[y][1]==x; //y的右儿是x
  ch[z][ch[z][1]==y]=x, fa[x]=z; //z的儿是x,x的父是z
  ch[y][k]=ch[x][k^1], fa[ch[x][k^1]]=y; //y的儿是x的异儿,x的异儿的父是y
  ch[x][k^1]=y, fa[y]=x; //x的异儿是y,y的父是x
  pushup(y); pushup(x);  //自底向上push
}
void splay(int x, int k){ //伸展
  while(fa[x]!=k){ //折线转xx,直线转yx
    int y=fa[x], z=fa[y]; 
    if(z!=k) (ch[y][0]==x)^(ch[z][0]==y)?rotate(x):rotate(y);
    rotate(x);
  }
  if(!k) root=x; //k=0时,x转到根
}
void insert(int v){ //插入
  int x=root, p=0;
  //x走到空节点或走到目标点结束
  while(x&&val[x]!=v) p=x,x=ch[x][v>val[x]];
  if(x) cnt[x]++; //目标点情况
  else{ //空节点情况
    x=++tot;
    ch[p][v>val[p]]=x;fa[x]=p;
    val[x]=v;cnt[x]=siz[x]=1;
  }
  splay(x,0);
}
void find(int v){ //找到v并转到根
  int x=root;
  while(ch[x][v>val[x]]&&v!=val[x]) 
    x=ch[x][v>val[x]]; 
  splay(x,0);
}
int getpre(int v){ //前驱
  find(v);
  int x=root;
  if(val[x]<v) return x;
  x=ch[x][0];
  while(ch[x][1]) x=ch[x][1];
  splay(x,0);
  return x;
}
int getsuc(int v){ //后继
  find(v);
  int x=root;
  if(val[x]>v) return x;
  x=ch[x][1];
  while(ch[x][0]) x=ch[x][0];
  splay(x,0);
  return x;
}
void del(int v){ //删除
  int pre=getpre(v);
  int suc=getsuc(v);
  splay(pre,0), splay(suc,pre);
  int del=ch[suc][0];
  if(cnt[del]>1)
    cnt[del]--, splay(del,0);
  else
    ch[suc][0]=0, splay(suc,0);
}
int getrank(int v){ //排名
  insert(v);
  int res=siz[ch[root][0]];
  del(v);
  return res;
}
int getval(int k){ //数值
  int x=root;
  while(true){
    if(k<=siz[ch[x][0]]) x=ch[x][0];
    else if(k<=siz[ch[x][0]]+cnt[x]) break;
    else k-=siz[ch[x][0]]+cnt[x], x=ch[x][1];
  }
  splay(x,0);
  return val[x];
}
int main(){
  insert(-INF);insert(INF); //哨兵
  int n,op,x; scanf("%d", &n);
  while(n--){
    scanf("%d%d", &op, &x);
    if(op==1) insert(x);
    else if(op==2) del(x);
    else if(op==3) printf("%d\n",getrank(x));
    else if(op==4) printf("%d\n",getval(x+1));
    else if(op==5) printf("%d\n",val[getpre(x)]);
    else printf("%d\n",val[getsuc(x)]);
  }
}

 

Luogu P6136 【模板】普通平衡树(数据加强版)

#include <iostream>
using namespace std;

#define ls(x) tr[x].s[0]
#define rs(x) tr[x].s[1]
const int N=1100010, INF=(1<<30)+1;
struct node{
  int s[2]; //左右儿子
  int p; //父亲
  int v; //节点权值
  int cnt; //权值出现次数
  int siz; //子树大小
  void init(int p1,int v1){
    p=p1, v=v1;
    cnt=siz=1;
  }
}tr[N];
int root; //根节点编号
int idx; //节点个数

void pushup(int x){
  tr[x].siz=tr[ls(x)].siz+tr[rs(x)].siz+tr[x].cnt;
}
void rotate(int x){
  int y=tr[x].p, z=tr[y].p;
  int k = tr[y].s[1]==x;
  tr[z].s[tr[z].s[1]==y] =x;
  tr[x].p = z;  
  tr[y].s[k] = tr[x].s[k^1];
  tr[tr[x].s[k^1]].p = y;
  tr[x].s[k^1] = y;
  tr[y].p = x;
  pushup(y), pushup(x);
}
void splay(int x, int k){
  while(tr[x].p!=k){
    int y=tr[x].p, z=tr[y].p;
    if(z!=k)   // 折转底,直转中
      (ls(y)==x)^(ls(z)==y)
        ? rotate(x) : rotate(y);
    rotate(x);
  }
  if(!k) root=x;
}
void insert(int v){ //插入
  int x=root, p=0;
  while(x && tr[x].v!=v)
    p=x, x=tr[x].s[v>tr[x].v];
  if(x) tr[x].cnt++;
  else{
    x=++idx;
    if(p) tr[p].s[v>tr[p].v]=x;
    tr[x].init(p,v);
  }
  splay(x, 0);
}
void find(int v){ //找到v并转到根
  int x=root;
  while(tr[x].s[v>tr[x].v]&&v!=tr[x].v) 
    x=tr[x].s[v>tr[x].v]; 
  splay(x, 0);
}
int getpre(int v){ //前驱
  find(v);
  int x=root;
  if(tr[x].v<v) return x;
  x=ls(x);
  while(rs(x)) x=rs(x);
  splay(x, 0);
  return x;
}
int getsuc(int v){ //后继
  find(v);
  int x=root;
  if(tr[x].v>v) return x;
  x=rs(x);
  while(ls(x)) x=ls(x);
  splay(x, 0);
  return x;
}
void del(int v){ //删除
  int pre=getpre(v);
  int suc=getsuc(v);
  splay(pre,0), splay(suc,pre);
  int del=tr[suc].s[0];
  if(tr[del].cnt>1)
    tr[del].cnt--, splay(del,0);
  else
    tr[suc].s[0]=0, splay(suc,0);
}
int getrank(int v){ //排名
  insert(v);
  int res=tr[tr[root].s[0]].siz;
  del(v);
  return res;
}
int getval(int k){ //数值
    int x=root;
    while(true){
        if(k<=tr[ls(x)].siz) x=ls(x);
        else if(k<=tr[ls(x)].siz+tr[x].cnt) break;
        else k-=tr[ls(x)].siz+tr[x].cnt, x=rs(x);
    }
    splay(x, 0);
    return tr[x].v;
}
int main(){
  insert(-INF);insert(INF); //哨兵
  int n,t; scanf("%d%d", &n,&t);
    for(int i=1; i<=n; i++){
        int x; scanf("%d", &x);
        insert(x);
    }  
    int res=0, last=0;
  while(t--){
    int op,x; scanf("%d%d", &op, &x);
    x^=last;    
    if(op==1) insert(x);
    if(op==2) del(x);
    if(op==3) res^=(last=getrank(x));
    if(op==4) res^=(last=getval(x+1));
    if(op==5) res^=(last=tr[getpre(x)].v);
    if(op==6) res^=(last=tr[getsuc(x)].v);
  }
  printf("%d\n",res);
  return 0;
}

 

posted @ 2022-07-19 20:07  董晓  阅读(2190)  评论(4编辑  收藏  举报