Splay

二叉查找树,对于任意一个节点,该节点的关键码大于它的左子树中任意节点的关键码,该节点的关键码小于它的右子树中任意节点的关键码,且没有键值相等的点

二叉查找树的中序遍历是一个关键码单调递增的节点序列

数组及变量

\(fa[i]:\) 节点\(i\)的父节点

\(son[i][0]:\) 节点\(i\)的左儿子

\(son[i][1]:\) 节点\(i\)的右儿子

\(val[i]:\) 节点\(i\)的关键字

\(siz[i]:\) 以节点\(i\)为根的子树元素个数

\(cnt[i]:\) 节点\(i\)所表示的元素的出现次数

\(tot:\) 共有多少元素

\(root:\) 树的根

函数

\(check:\) 判断节点\(x\)是它父亲的左儿子还是右儿子

\(pushup:\) 更新节点\(x\)\(siz\)

\(rotate:\) 将是左儿子的右旋,是右儿子的左旋

\(splay :\) 进行伸展,不断\(rotate\)直到达到目标状态

\(insert:\) 插入一个值

\(find:\) 查找\(x\)的位置,并将其旋转到根节点

\(query\_rnk:\) 查询\(x\)的排名

\(query\_val:\) 查询排名为\(x\)的数

\(get:\) \(k=0\)时,求\(x\)的前驱,\(k=1\)时,求\(x\)的后继

\(del:\) 删除为\(x\)的数

\(code\)

bool check(int x)
{
    return ch[fa[x]][1]==x;
}
void pushup(int x)
{
    siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+cnt[x];
}
void rotate(int x)
{
    int y=fa[x],z=fa[y],k=check(x),w=ch[x][k^1];
    ch[z][check(y)]=x,ch[x][k^1]=y,ch[y][k]=w;
    fa[w]=y,fa[x]=z,fa[y]=x;
    pushup(y),pushup(x);
}
void splay(int x,int goal)
{
    for(int y;fa[x]!=goal;rotate(x))
        if(fa[y=fa[x]]!=goal)
            rotate(check(x)^check(y)?x:y);
    if(!goal) root=x;
}
void insert(int x)
{
    int p=root,f=0;
    while(p&&val[p]!=x) f=p,p=ch[p][val[p]<x];
    if(p) cnt[p]++;
    else p=++tot,ch[f][val[f]<x]=p,fa[p]=f,val[p]=x,cnt[p]=1;
    splay(p,0);
}
void find(int x)
{
    int p=root;
    while(ch[p][val[p]<x]&&x!=val[p]) p=ch[p][val[p]<x];
    splay(p,0);
}
int query_rnk(int x)
{
    find(x);
    return siz[ch[root][0]];
}
int query_val(int x)
{
    x++;
    int p=root;
    while(1)
    {
        if(x<=siz[ch[p][0]]) p=ch[p][0];
        else
        {
            x-=siz[ch[p][0]]+cnt[p];
            if(x<=0) return val[p];
            p=ch[p][1];
        }
    }
}
int get(int x,int k)
{
    find(x);
    if(val[root]>x&&k) return root;
    if(val[root]<x&&!k) return root;
    int p=ch[root][k];
    while(ch[p][k^1]) p=ch[p][k^1];
    return p;
}
void del(int x)
{
    int pre=get(x,0),nxt=get(x,1);
    splay(pre,0),splay(nxt,pre);
    int d=ch[nxt][0];
    if(cnt[d]>1) cnt[d]--,splay(d,0);
    else ch[nxt][0]=0;
}

......

insert(inf),insert(-inf);

insert(a)
del(a)
query_rnk(a)
query_val(a)
key[get(a,0)]
key[get(a,1)]

\(Splay\)进行序列操作,按序列编号为关键字建二叉搜索树,二叉搜索树的中序遍历为原序列

数列

维护一个数列,共 \(7\) 种操作:

I. INSERT x n a1 a2 .. an 在第 \(x\) 个数后插入 \(n\) 个数分别为 \(a_1\dots a_n\)

II. DELETE x n 删除第 \(x\) 个数开始的 \(n\) 个数。

III. REVERSE x n 翻转第 \(x\) 个数开始的 \(n\) 个数的区间。

IV. MAKE-SAME x n t 将第 \(x\) 个数开始的 \(n\) 个数统一改为 \(t\)

V. GET-SUM x n 输出第 \(x\) 个数开始的 \(n\) 个数的和。

VI. GET x 输出第 \(x\) 个数的值。

VII. MAX-SUM x n 输出第 \(x\) 个数开始的 \(n\) 个数的最大连续子序列和。

\(code:\)

bool check(int x)
{
    return ch[fa[x]][1]==x;
}
void pushup(int x)
{
    int ls=ch[x][0],rs=ch[x][1];
    siz[x]=siz[ls]+siz[rs]+1;
    sum[x]=sum[ls]+sum[rs]+val[x];
    lm[x]=max(lm[ls],sum[ls]+val[x]+lm[rs]);
    rm[x]=max(rm[rs],sum[rs]+val[x]+rm[ls]);
    ma[x]=max(val[x]+lm[rs]+rm[ls],max(ma[ls],ma[rs]));
}
void pushr(int x)
{
    rev[x]^=1,swap(ch[x][0],ch[x][1]),swap(lm[x],rm[x]);
}
void pushv(int x,int v)
{
    if(!x) return;
    tag[x]=1,val[x]=v,sum[x]=v*siz[x];
    lm[x]=rm[x]=max(sum[x],0),ma[x]=max(sum[x],val[x]);
}
void pushdown(int x)
{
    int ls=ch[x][0],rs=ch[x][1];
    if(tag[x]) pushv(ls,val[x]),pushv(rs,val[x]);
    if(rev[x]) pushr(ls),pushr(rs);
    tag[x]=rev[x]=0;
}
int add()
{
    int x=top?st[top--]:++tot;
    fa[x]=ch[x][0]=ch[x][1]=rev[x]=siz[x]=tag[x]=0;
    return x;
}
void build(int l,int r,int &x,int *a)
{
    x=add();
    int mid=(l+r)>>1;
    lm[x]=rm[x]=max(a[mid],0);
    val[x]=ma[x]=sum[x]=a[mid];
    if(l<mid) build(l,mid-1,ch[x][0],a);
    if(r>mid) build(mid+1,r,ch[x][1],a);
    fa[ch[x][0]]=fa[ch[x][1]]=x;
    pushup(x);
}
void rotate(int x)
{
    int y=fa[x],z=fa[y],k=check(x),w=ch[x][k^1];
    ch[z][check(y)]=x,ch[x][k^1]=y,ch[y][k]=w;
    fa[w]=y,fa[x]=z,fa[y]=x;
    pushup(y),pushup(x);
}
void splay(int x,int goal)
{
    for(int y;fa[x]!=goal;rotate(x))
        if(fa[y=fa[x]]!=goal)
            rotate(check(x)^check(y)?x:y);
    if(!goal) root=x;
}
int kth(int x,int rk)
{
    pushdown(x);
    int ls=ch[x][0],rs=ch[x][1];
    if(rk==siz[ls]+1) return x;
    if(rk<=siz[ls]) return kth(ls,rk);
    return kth(rs,rk-siz[ls]-1);
}
void split(int l,int r)
{
    l=kth(root,l-1),r=kth(root,r+1),splay(l,0),splay(r,l);
}
void insert(int x,int num)
{
    int t,p;
    build(1,num,t,c);
    split(x+1,x);
    p=ch[root][1];
    ch[p][0]=t,fa[t]=p;
    pushup(p),pushup(root);
}
void del(int x)
{
    if(!x) return;
    st[++top]=x;
    del(ch[x][0]),del(ch[x][1]);
}
void erase(int l,int r)
{
    int p;
    split(l,r);
    p=ch[root][1];
    del(ch[p][0]),ch[p][0]=0;
    pushup(p),pushup(root);
}
void cover(int l,int r,int v)
{
    int p;
    split(l,r);
    p=ch[root][1];
    pushv(ch[p][0],v);
    pushup(p),pushup(root);
}
void reverse(int l,int r)
{
    int p;
    split(l,r);
    p=ch[root][1];
    pushr(ch[p][0]);
    pushup(p),pushup(root);
}
int query_sum(int l,int r)
{
    int p;
    split(l,r);
    p=ch[root][1];
    return sum[ch[p][0]];
}
int query_max(int l,int r)
{
    int p;
    split(l,r);
    p=ch[root][1];
    return ma[ch[p][0]];
}

......

if(opt=="GET") read(x),printf("%d\n",val[kth(root,x+1)]);
else read(x),read(num),x++;
if(opt=="INSERT")
{
    for(int i=1;i<=num;++i) read(c[i]);
    insert(x,num);
}
if(opt=="DELETE") erase(x,x+num-1);
if(opt=="REVERSE") reverse(x,x+num-1);
if(opt=="MAKE-SAME") read(v),cover(x,x+num-1,v);
if(opt=="GET-SUM") printf("%d\n",query_sum(x,x+num-1));
if(opt=="MAX-SUM") printf("%d\n",query_max(x,x+num-1));
posted @ 2020-01-22 20:12  lhm_liu  阅读(268)  评论(0编辑  收藏  举报