splay板子
1, splay的一些基本操作.
- 使用前要插入$-INF,+INF$保证每个点的前驱后继存在.
- $get$函数在$x$存在时, 调用后, 根为$x$, 否则根为$x$的前驱或后继
const int N = 1e6+10; int n, tot, rt, sz; struct { int cnt,sz,fa,ch[2],v; } tr[N]; void pu(int x) { tr[x].sz=tr[tr[x].ch[0]].sz+tr[tr[x].ch[1]].sz+tr[x].cnt; } void rot(int x) { int y=tr[x].fa,z=tr[y].fa; int f=tr[y].ch[1]==x; tr[z].ch[tr[z].ch[1]==y]=x,tr[x].fa=z; tr[y].ch[f]=tr[x].ch[f^1],tr[tr[x].ch[f^1]].fa=y; tr[x].ch[f^1]=y,tr[y].fa=x,pu(y); } void splay(int x, int s=0) { for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) { rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x); } if (!s) rt=x; } void get(int x) { int cur=rt; while (x!=tr[cur].v&&tr[cur].ch[x>tr[cur].v]) cur=tr[cur].ch[x>tr[cur].v]; splay(cur); } void insert(int x) { int cur=rt,p=0; while (cur&&x!=tr[cur].v) p=cur,cur=tr[cur].ch[x>tr[cur].v]; if (cur) ++tr[cur].cnt; else { cur=++tot; if (p) tr[p].ch[x>tr[p].v]=cur,tr[cur].fa=p; tr[cur].v=x,tr[cur].sz=tr[cur].cnt=1; } splay(cur); } int pre(int x) { get(x); if (tr[rt].v<=x) return rt; int cur=tr[rt].ch[0]; while (tr[cur].ch[1]) cur=tr[cur].ch[1]; return cur; } int nxt(int x) { get(x); if (tr[rt].v>=x) return rt; int cur=tr[rt].ch[1]; while (tr[cur].ch[0]) cur=tr[cur].ch[0]; return cur; } void erase(int x) { int s1=pre(x-1),s2=nxt(x+1); splay(s1),splay(s2,s1); int &cur=tr[s2].ch[0]; if (tr[cur].cnt>1) --tr[cur].cnt,splay(cur); else cur=0; }
2, splay插入区间,区间翻转等操作.
这时候splay维护的是每个下标对应的权值, 下标通过第k大来查询
- 使用前要调用$build(a,0,rt,1,2);$
const int N = 1e6+10; int n, rt, tot; int a[N]; struct _ { int sz,v,ch[2],fa,rev; } tr[N]; void pu(int o) { tr[o].sz=tr[tr[o].ch[0]].sz+tr[tr[o].ch[1]].sz+1; } void pd(int o) { if (tr[o].rev) { swap(tr[o].ch[0],tr[o].ch[1]); tr[tr[o].ch[0]].rev^=1; tr[tr[o].ch[1]].rev^=1; tr[o].rev=0; } } void rot(int x) { int y=tr[x].fa,z=tr[y].fa; int f=tr[y].ch[1]==x; tr[z].ch[tr[z].ch[1]==y]=x,tr[x].fa=z; tr[y].ch[f]=tr[x].ch[f^1],tr[tr[x].ch[f^1]].fa=y; tr[x].ch[f^1]=y,tr[y].fa=x,pu(y); } void splay(int x, int s=0) { for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) { rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x); } if (!s) rt=x; } int find(int x, int k) { pd(x); int s=tr[tr[x].ch[0]].sz; if (k==s+1) return x; if (k<=s) return find(tr[x].ch[0],k); return find(tr[x].ch[1],k-s-1); } void build(int *a, int f, int &o, int l, int r) { if (l>r) return; o = ++tot; tr[o].v = a[mid], tr[o].fa = f; build(s,o,tr[o].ch[0],l,mid-1); build(s,o,tr[o].ch[1],mid+1,r); pu(o); } void ins(int x, int n) { build(a,0,p,1,n); int s1=find(rt,x-1), s2=find(rt,x); splay(s1),splay(s2,s1); tr[s2].ch[0]=p,tr[p].fa=s2; pu(p),pu(s2); } void del(int x, int n) { int s1=find(rt,x-1), s2=find(rt,x+n); splay(s1),splay(s2,s1); tr[s2].ch[0]=0; pu(s1),pu(s2); } void reverse(int x, int n) { int s1=find(rt,x-1), s2=find(rt,x+n); splay(s1),splay(s2,s1); tr[tr[s2].ch[0]].rev^=1; }