花了一天时间调试终于过了。网上非指针的splay太少了,学长大佬们拒绝指导,只能靠自己了。
其实splay是为了保证时间复杂度,所以每次操作进行一次splay。
splay共三种操作:zig,zig_zig,zig_zag。
设当前点为x,y=fa[x],z=fa[y]
1.zig
y为root时直接rotate(x)
2.zig_zig
当x,y,z构成一条链时,先rotate(y),再rotate(x)
3.zig_zag
当非前两种情况时,做两次rotate(x)
即
void splay(int x) { for(int fa;(fa=tr[x].f);rotate(x)) if(tr[fa].f) rotate((get(fa)==get(x))?fa:x); root=x; }
ac代码:
#include<bits/stdc++.h> #define maxn 100005 #define ls tr[x].ch[0] #define rs tr[x].ch[1] using namespace std; struct node{ int ch[2],f,key,cnt,size; }tr[maxn]; int sz,root,n; void clear(int x) { tr[x].ch[1]=tr[x].ch[0]=tr[x].cnt=tr[x].f=tr[x].key=tr[x].size=0; } int get(int x) { return tr[tr[x].f].ch[1]==x; } void update(int x) { if(x) { tr[x].size=tr[x].cnt; if(ls)tr[x].size+=tr[ls].size; if(rs)tr[x].size+=tr[rs].size; } } void rotate(int x) { int y=tr[x].f,z=tr[y].f,whi=get(x); tr[y].ch[whi]=tr[x].ch[whi^1];tr[tr[y].ch[whi]].f=y; tr[x].ch[whi^1]=y; tr[x].f=z; if(z)tr[z].ch[get(y)]=x; tr[y].f=x; update(y);update(x); } void splay(int x) { for(int fa;(fa=tr[x].f);rotate(x)) if(tr[fa].f) rotate((get(fa)==get(x))?fa:x); root=x; } void insert(int x) { if(root==0){root=++sz;tr[sz].key=x;tr[sz].cnt=tr[sz].size=1;return;} int now=root; while(1) { if(tr[now].key==x){tr[now].cnt++;tr[now].size++;splay(now);return;} int fa=now; now=tr[now].ch[tr[now].key<x]; if(now==0) { tr[++sz].f=fa; tr[sz].cnt=tr[sz].size=1; tr[sz].key=x; tr[fa].ch[tr[fa].key<x]=sz; update(fa); splay(sz); return; } } } int find_pm(int x) { int now=root,ans=0; while(1) { if(tr[now].key<=x) { ans+=(tr[now].ch[0]?tr[tr[now].ch[0]].size:0); if(tr[now].key==x){splay(now);return ans+1;} ans+=tr[now].cnt; now=tr[now].ch[1]; } else now=tr[now].ch[0]; } } int find_wz(int x) { int now=root; while(1) { int lsz=tr[tr[now].ch[0]].size; if(lsz>=x)now=tr[now].ch[0]; else if(lsz+tr[now].cnt>=x)return tr[now].key; else { x-=lsz+tr[now].cnt; now=tr[now].ch[1]; } } } int pre() { int now=tr[root].ch[0],x=tr[root].key; while(tr[now].ch[1]!=0) { now=tr[now].ch[1]; } return now; } int nex() { int now=tr[root].ch[1],x=tr[root].key; while(tr[now].ch[0]!=0) { now=tr[now].ch[0]; } return now; } void del(int x) { int w=find_pm(x); if(tr[root].cnt>1){tr[root].cnt--;tr[root].size--;return;} if(tr[root].ch[0]==0&&tr[root].ch[1]==0){clear(root);root=0;return;} if(tr[root].ch[0]==0) { int newroot=tr[root].ch[1]; clear(root); root=newroot; tr[root].f=0; return; } else if(tr[root].ch[1]==0) { int newroot=tr[root].ch[0]; clear(root); root=newroot; tr[root].f=0; return; } else { int nowroot=pre(),oldroot=root; splay(nowroot); tr[nowroot].ch[1]=tr[oldroot].ch[1]; tr[tr[oldroot].ch[1]].f=nowroot; root=nowroot; tr[root].f=0; clear(oldroot); update(root); } } int main() { scanf("%d",&n); for(int i=1;i<=n;i++) { int temp,x; scanf("%d%d",&temp,&x); if(temp==1)insert(x); if(temp==2)del(x); if(temp==3)printf("%d\n",find_pm(x)); if(temp==4)printf("%d\n",find_wz(x)); if(temp==5)insert(x),printf("%d\n",tr[pre()].key),del(x); if(temp==6)insert(x),printf("%d\n",tr[nex()].key),del(x); } return 0; }