bzoj3196 二逼平衡树——线段树套平衡树

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3196

人生中第一棵树套树!

写了一个晚上,成功卡时 9000ms+ 过了!

很要注意数组的大小,因为是树*树的大小嘛!

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
int const maxn=2e6+5,nn=5e4+5,inf=1e9;//2e6!!
int n,m,tot,siz[maxn],c[maxn][3],fa[maxn],key[maxn],mn,mx,a[nn];
struct N{
    int l,r,rt;
}t[nn<<2];
void update(int x){siz[x]=siz[c[x][0]]+siz[c[x][1]]+1;}
void rotate(int x)
{
    int y=fa[x],z=fa[y],d=(c[y][1]==x);
    if(z!=0)c[z][c[z][1]==y]=x;
    fa[x]=z; fa[y]=x; fa[c[x][d^1]]=y;
    c[y][d]=c[x][d^1]; c[x][d^1]=y; 
    update(y); update(x);
}
void splay(int x,int f)//fa[x]=f
{
    while(fa[x]!=f)
    {
        int y=fa[x],z=fa[y];
        if(z!=f)// z 而非 y 
        {
            if((c[y][1]==x)^(c[z][1]==y)) rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
}
int find_pre(int x,int v)
{
    if(!x)return 0;
    if(key[x]>=v) return find_pre(c[x][0],v);
    else
    {
        int y=find_pre(c[x][1],v);
        if(y)return y; else return x;
    }
}
int find_k(int x,int k)
{
//    if(!x)return 0;
    if(siz[c[x][0]]==k-1)return x;
    if(siz[c[x][0]]>k-1)return find_k(c[x][0],k);
    return find_k(c[x][1],k-siz[c[x][0]]-1);
}
int find_suc(int x,int k)//后继 
{
    if(!x)return 0;
    if(key[x]<=k)return find_suc(c[x][1],k);
    else
    {
        int y=find_suc(c[x][0],k);
        if(y)return y; else return x;//说不定就是 x 
    }
}
int find_Suc(int x,int k)//相等中最小的 
{
    if(!x)return 0;
    if(key[x]<k)return find_Suc(c[x][1],k);
    else
    {
        int y=find_Suc(c[x][0],k);
        if(y)return y; else return x;//说不定就是 x 
    }
}
int query_k(int &x,int k)//查询 k 的排名 
{
    int p=find_Suc(x,k); splay(p,0);//相等中最小的
    x=p;
    return siz[c[p][0]]-1;// -inf
}
int query(int &x,int k)//二分用 
{
    int p=find_suc(x,k); splay(p,0);
    x=p;
    return siz[c[p][0]]-1;// -inf
}
void insert(int &x,int v)
{
    int p=find_pre(x,v); splay(p,0);
    int y=find_k(c[p][1],1); splay(y,p);//p
    tot++; siz[tot]=1; fa[tot]=y; key[tot]=v; c[y][0]=tot;
    update(y); update(p);
    x=p;//rt
}
void del(int &x,int v)
{
    int p=find_pre(x,v); splay(p,0);
    int y=find_k(c[p][1],2); splay(y,p);//p
    fa[c[y][0]]=0; c[y][0]=0;
    update(y); update(p);
    x=p;
}
void build(int x,int l,int r)
{
    t[x].l=l; t[x].r=r;
    t[x].rt=++tot; siz[tot]=2; fa[tot]=0; c[tot][1]=tot+1; key[tot]=-inf;
    tot++; siz[tot]=1; fa[tot]=tot-1; key[tot]=inf;
    for(int i=l;i<=r;i++)insert(t[x].rt,a[i]);
    if(l==r)return;
    int mid=(l+r)>>1;
    build(x<<1,l,mid); build(x<<1|1,mid+1,r);
}
int query_k(int x,int l,int r,int k)//查询 k 的排名 
{
    if(t[x].l>=l && t[x].r<=r)return query_k(t[x].rt,k);//<=k 的个数 
    int mid=(t[x].l+t[x].r)>>1,ret=0;//不是 l,r 
    if(mid>=l)ret+=query_k(x<<1,l,r,k);//l,r,无需 l,mid ,因为自带 t[x].l,t[x].r 
    if(mid<r)ret+=query_k(x<<1|1,l,r,k);
    return ret;
}
int query(int x,int l,int r,int k)//二分用 //<=mid的个数 
{
    if(t[x].l>=l && t[x].r<=r)return query(t[x].rt,k);
    int mid=(t[x].l+t[x].r)>>1,ret=0;//不是 l,r 
    if(mid>=l)ret+=query(x<<1,l,r,k);
    if(mid<r)ret+=query(x<<1|1,l,r,k);
    return ret;
}
void modify(int x,int p,int k)
{
    insert(t[x].rt,k); del(t[x].rt,a[p]); 
    if(t[x].l==t[x].r)return;//
    int mid=(t[x].l+t[x].r)>>1;
    if(p<=mid)modify(x<<1,p,k);
    else modify(x<<1|1,p,k);
}
int query_pre(int x,int l,int r,int k)
{
    if(t[x].l>=l && t[x].r<=r)return key[find_pre(t[x].rt,k)];//返回 key ,因为需要比较
    int mid=(t[x].l+t[x].r)>>1,ret=0;
    if(mid>=l)ret=max(ret,query_pre(x<<1,l,r,k));
    if(mid<r)ret=max(ret,query_pre(x<<1|1,l,r,k)); 
    return ret;
}
int query_suc(int x,int l,int r,int k)
{
    if(t[x].l>=l && t[x].r<=r)return key[find_suc(t[x].rt,k)];//返回 key ,因为需要比较
    int mid=(t[x].l+t[x].r)>>1,ret=inf;
    if(mid>=l)ret=min(ret,query_suc(x<<1,l,r,k));
    if(mid<r)ret=min(ret,query_suc(x<<1|1,l,r,k)); 
    return ret;
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&a[i]);
        mn=min(mn,a[i]); mx=max(mx,a[i]);
    }
    build(1,1,n);
    for(int i=1,op,l,r,k;i<=m;i++)
    {
        scanf("%d",&op);
        if(op!=3)scanf("%d%d%d",&l,&r,&k);
        if(op==1)printf("%d\n",query_k(1,l,r,k)+1);//+1
        if(op==2)
        {
            int L=mn,R=mx,ans=0;
            while(L<=R)
            {
                int mid=(L+R)>>1;
                if(query(1,l,r,mid)>=k)ans=mid,R=mid-1;// >也可,因为查询的是 <=mid 的个数 
                else L=mid+1;
            }
            printf("%d\n",ans);
        }
        if(op==3)
        {
            int pos; scanf("%d%d",&pos,&k);
            modify(1,pos,k);
            a[pos]=k;//!!!
        }
        if(op==4)printf("%d\n",query_pre(1,l,r,k));
        if(op==5)printf("%d\n",query_suc(1,l,r,k));
    }
    return 0;
}

 

posted @ 2018-06-24 21:55  Zinn  阅读(226)  评论(0编辑  收藏  举报