BZOJ3196二逼平衡树——线段树套平衡树(treap)

此为平衡树系列最后一道:二逼平衡树您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:

1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)
5.查询k在区间内的后继(后继定义为大于x,且最小的数)

输入

第一行两个数 n,m 表示长度为n的有序序列和m个操作
第二行有n个数,表示有序序列
下面有m行,opt表示操作标号
若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名
若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数
若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k
若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱
若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继

样例输入

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

样例输出

2 4 3 4 9

提示

n,m<=50000   保证有序序列所有值在任何时刻满足[0,10^8]

这道题相对于普通平衡树就是把所有查询操作从整个数列改为在一个区间里查询,因此需要用线段树来维护区间,然后在区间上完成相应的操作。

线段树套平衡树就是在线段树上的每一个节点建一棵平衡树来维护这个节点所对应区间的信息,因为线段树最多有logn层,每个数在一层只出现一次,所以所有平衡树节点数是nlogn,时间复杂度是O(m*logn^2)。原理就是在线段树上先找到所查区间对应的分区间,在每个分区间节点的平衡树上查找信息再合并。

1、4、5操作不做说明和平衡树一样。3操作就是把这个数删掉再把改完值的数加进去。重点是2操作,需要二分答案(二分k的值),然后用1操作来验证。

树套树代码一般都很长而且不好调试但很好理解。

最后附上代码。

#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<cstring>
#include<queue>
#include<vector>
using namespace std;
int n,m;
int tot;
int a[100010];
int l,r,k,opt,x;
struct intree
{
    int s[2];
    int w;
    int r;
    int v;
    int size;
};
intree treap[2000100];
int inbuild(int k)
{
    tot++;
    treap[tot].r=rand();
    treap[tot].v=k;
    treap[tot].size=1;
    treap[tot].w=1;
    treap[tot].s[0]=treap[tot].s[1]=0;
    return tot;
}
void updata(int i)
{
    treap[i].size=treap[i].w+treap[treap[i].s[0]].size+treap[treap[i].s[1]].size;
}
void rotate(int &x,int i)
{
    int p=treap[x].s[i];
    treap[x].s[i]=treap[p].s[i^1];
    treap[p].s[i^1]=x;
    updata(x);
    updata(p);
    x=p;
}
void insert(int &x,int k)
{
    if(!x)
    {
        x=inbuild(k);
    }
    else if(treap[x].v==k)
    {
        treap[x].w++;
    }
    else
    {
        if(treap[x].v<k)
        {
            insert(treap[x].s[1],k);
            if(treap[treap[x].s[1]].r>treap[x].r)
            {
                rotate(x,1);
            }
        }
        else
        {
            insert(treap[x].s[0],k);
            if(treap[treap[x].s[0]].r>treap[x].r)
            {
                rotate(x,0);
            }
        }
    }
    updata(x);
}
void del(int &x,int k)
{
    if(treap[x].v<k)
    {
        del(treap[x].s[1],k);
    }
    else if(treap[x].v>k)
    {
        del(treap[x].s[0],k);
    }
    else
    {
        if(treap[x].w>1)
        {
            treap[x].w--;
        }
        else
        {
            if(!treap[x].s[0]&&!treap[x].s[1])
            {
                x=0;
            }
            else if(!treap[x].s[0])
            {
                rotate(x,1);
                del(treap[x].s[0],k);
            }
            else if(!treap[x].s[1])
            {
                rotate(x,0);
                del(treap[x].s[1],k);
            }
            else
            {
                if(treap[treap[x].s[0]].r>treap[treap[x].s[1]].r)
                {
                    rotate(x,0);
                    del(treap[x].s[1],k);
                }
                else
                {
                    rotate(x,1);
                    del(treap[x].s[0],k);
                }
            }   
        }
    }
    if(x)
    {
        updata(x);
    }
}
int inrank(int x,int k)
{
    if(!x)
    {
        return 0;
    }
    if(treap[x].v>k)
    {
        return inrank(treap[x].s[0],k);
    }
    else if(treap[x].v==k)
    {
        return treap[treap[x].s[0]].size;
    }
    else
    {
        return inrank(treap[x].s[1],k)+treap[treap[x].s[0]].size+treap[x].w;
    }
}
int inpre(int x,int k)
{
    if(!x)
    {
        return -2147483647;
    }
    if(treap[x].v>=k)
    {
        return inpre(treap[x].s[0],k);
    }
    else
    {
        return max(treap[x].v,inpre(treap[x].s[1],k));
    }
}
int insuf(int x,int k)
{
    if(!x)
    {
        return 2147483647;
    }
    if(treap[x].v<=k)
    {
        return insuf(treap[x].s[1],k);
    }
    else
    {
        return min(treap[x].v,insuf(treap[x].s[0],k));
    }
}
struct outtree
{
    int l;
    int r;
    int root;
};
outtree tree[2000100];
void outbuild(int x,int l,int r)
{
    tree[x].l=l;
    tree[x].r=r;
    for(int i=l;i<=r;i++)
    {
        insert(tree[x].root,a[i]);
    }
    if(l!=r)
    {
        int mid=(l+r)>>1;
        outbuild(x*2,l,mid);
        outbuild(x*2+1,mid+1,r);
    }
}
void change(int i,int x,int y)
{
    del(tree[i].root,a[x]);
    insert(tree[i].root,y);
    if(tree[i].l==tree[i].r)
    {
        return ;
    }
    int mid=(tree[i].l+tree[i].r)>>1;
    if(mid<x)
    {
        change(i*2+1,x,y);
    }
    else
    {
        change(i*2,x,y);
    }
}
int outrank(int x,int l,int r,int k)
{
    if(tree[x].l>r||tree[x].r<l)
    {
        return 0;
    }
    if(tree[x].l>=l&&tree[x].r<=r)
    {
        return inrank(tree[x].root,k);
    }
    else
    {
        return outrank(x*2,l,r,k)+outrank(x*2+1,l,r,k);
    }
}
int outsum(int l,int r,int k)
{
    int L=0;
    int R=1e8;
    while(L<R)
    {
        int mid=(L+R+1)>>1;
        if(outrank(1,l,r,mid)<k)
        {
            L=mid;
        }
        else
        {
            R=mid-1;
        }
    }
    return R;
}
int outpre(int x,int l,int r,int k)
{
    if(tree[x].l>r||tree[x].r<l)
    {
        return -2147483647;
    }
    if(tree[x].l>=l&&tree[x].r<=r)
    {
        return inpre(tree[x].root,k);
    }
    else
    {
        return max(outpre(x*2,l,r,k),outpre(x*2+1,l,r,k));
    }
}
int outsuf(int x,int l,int r,int k)
{
    if(tree[x].l>r||tree[x].r<l)
    {
        return 2147483647;
    }
    if(tree[x].l>=l&&tree[x].r<=r)
    {
        return insuf(tree[x].root,k);
    }
    else
    {
        return min(outsuf(x*2,l,r,k),outsuf(x*2+1,l,r,k));
    }
}
int main()
{
    srand(12378);
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&a[i]);
    }
    outbuild(1,1,n);
    for(int i=1;i<=m;i++)
    {
        scanf("%d",&opt);
        if(opt==1)
        {
            scanf("%d%d%d",&l,&r,&k);
            printf("%d\n",outrank(1,l,r,k)+1);
        }
        else if(opt==2)
        {
            scanf("%d%d%d",&l,&r,&k);
            printf("%d\n",outsum(l,r,k));
        }
        else if(opt==3)
        {
            scanf("%d%d",&k,&x);
            change(1,k,x);
            a[k]=x;
        }
        else if(opt==4)
        {
            scanf("%d%d%d",&l,&r,&k);
            printf("%d\n",outpre(1,l,r,k));
        }
        else
        {
            scanf("%d%d%d",&l,&r,&k);
            printf("%d\n",outsuf(1,l,r,k));
        }
    }
    return 0;
}
posted @ 2018-05-25 18:49  The_Virtuoso  阅读(774)  评论(2编辑  收藏  举报