树套树(splay套线段树) -AcWing 2476

树套树(splay套线段树) -AcWing 2476

本来想用 multiset 套线段树，结果一直 TLE（超时）；改成常数更小的 splay 才通过，写完人都傻了 ^^

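/*
 Tree-in-tree (splay套线段树): an outer segment tree over positions 1..n,
 where every segment-tree node owns a splay tree holding the values of
 its range, bracketed by -INF / INF sentinel keys.
 */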

#include <bits/stdc++.h>
using namespace std;

// N: bound on the sequence length n (arr is 1-indexed; the segment tree
//    uses N<<2 slots). Presumably n <= 5e4 per the problem — TODO confirm.
// M: capacity of the shared splay node pool tr[M]; this is a large static
//    array (5 ints per node).
const int N = 5e4+5;
const int M = 1e7+5;
// INF is both the sentinel key inserted at both ends of every splay and the
// value printed when a predecessor / successor does not exist. Using
// 2147483647 (INT_MAX) makes those answers -2147483647 / 2147483647, the
// convention of the original problem (Luogu P3380). It only has to exceed
// every real key (rank_k searches [0, 1e8]), and -INF still fits in int.
const int INF = 2147483647;

// One node of the shared splay pool.
//   s[0] / s[1] : left / right child index (0 = null node)
//   p           : parent index (0 = this node is a root)
//   v           : key
//   size        : number of nodes in this subtree (kept up to date by push_up)
struct SplayNode{
    int s[2],p,v;
    int size;
    // Turn this slot into a fresh leaf with key _v hanging under _p.
    // size must start at 1: a node inserted into an empty splay is never
    // rotated, so push_up would otherwise never set its size.
    void init(int _v,int _p){
        v = _v;
        p = _p;
        s[0] = s[1] = 0;
        size = 1;
    }
};


// n = sequence length, m = number of operations;
// op / L / R / X hold the operation currently being read in main().
int n, m;
int op;
int L, R, X;

// Index of the most recently allocated splay node (0 is the null node).
int idx;
// root[rt]: root node of the splay owned by segment-tree node rt.
int root[N<<2];
// Node pool shared by every splay.
SplayNode tr[M];

// arr[i]: current value at position i (1-indexed).
int arr[N];

// Side of its parent that x hangs on: 1 for a right child, 0 otherwise.
inline int ws(int x){
    const int parent = tr[x].p;
    return tr[parent].s[1] == x ? 1 : 0;
}

inline void push_up(int x){
    tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + 1;
}

// Single rotation: lift x above its parent y while preserving in-order
// (BST) order. x's inner child (the subtree lying between x and y in key
// order) is handed over to y. Sizes are recomputed for y first and then x,
// because y ends up as x's child.
inline void rotate(int x){
    int y = tr[x].p;
    int z = tr[y].p;
    int k = ws(x);          // side of y that x hangs on (0 = left, 1 = right)
    
    // modify z: x replaces y as z's child. If y was a root (z == 0) this
    // writes into the null node tr[0]; nothing in this file follows links
    // out of tr[0], so that write is harmless.
    tr[z].s[ws(y)] = x;
    
    // modify y: y adopts x's inner child on x's old side; x becomes y's parent
    tr[y].s[k] = tr[x].s[k^1];
    tr[y].p = x;
    
    // modify m (x's inner / "middle" child): re-parent it to y.
    // When that child is null this writes tr[0].p, harmless as above.
    tr[tr[x].s[k^1]].p = y;
    
    // modify x: x takes y's old place under z, with y as its inner child
    tr[x].p = z;
    tr[x].s[k^1] = y;
    
    push_up(y);
    push_up(x);
}

// Rotate x upward until its parent is k. With k == 0, x becomes the root of
// splay tr_id and root[tr_id] is updated.
// When x and its parent lean the same way, rotate the parent first
// (zig-zig); otherwise rotate x twice (zig-zag).
inline void splay(int x,int k,int tr_id){
    while(tr[x].p != k){
        const int y = tr[x].p;
        const int z = tr[y].p;
        if(z != k) rotate(ws(x) == ws(y) ? y : x);
        rotate(x);
    }
    if(k == 0) root[tr_id] = x;
}

// Allocate a node with key v, attach it as a leaf of splay tr_id (equal
// keys descend to the left), splay it to the root, and return its index.
inline int insert(int v,int tr_id){
    int parent = 0;
    for(int cur = root[tr_id]; cur != 0; cur = tr[cur].s[v > tr[cur].v]){
        parent = cur;
    }
    const int fresh = ++idx;
    tr[fresh].init(v, parent);
    if(parent != 0){
        tr[parent].s[v > tr[parent].v] = fresh;
    }
    splay(fresh, 0, tr_id);
    return fresh;
}

inline int count_less(int v,int tr_id){
    int u = root[tr_id], cnt = 0;
    while(u){
        if(tr[u].v < v){
            cnt += tr[tr[u].s[0]].size + 1;
            u = tr[u].s[1];
        }else{
            u = tr[u].s[0];
        }
    }
    return cnt-1; // -INF
}

// In-order predecessor of x within x's own subtree: the rightmost node of
// x's left subtree, or -1 when x has no left child.
inline int get_l(int x){
    int node = tr[x].s[0];
    if(node == 0) return -1;
    for(int next = tr[node].s[1]; next != 0; next = tr[node].s[1]){
        node = next;
    }
    return node;
}


// In-order successor of x within x's own subtree: the leftmost node of
// x's right subtree, or -1 when x has no right child.
inline int get_r(int x){
    int node = tr[x].s[1];
    if(node == 0) return -1;
    for(int next = tr[node].s[0]; next != 0; next = tr[node].s[0]){
        node = next;
    }
    return node;
}

// Replace one occurrence of key x by key v in splay tr_id.
// Assumes x is present and is a real key (modify() passes the old
// arr[pos]). The -INF / INF sentinels then guarantee that x has both a
// predecessor and a successor. After splaying the predecessor to the root
// and the successor directly beneath it, the target is the successor's
// childless left child and can simply be cut off.
inline void update(int x,int v,int tr_id){
    int target = root[tr_id];
    while(target && tr[target].v != x){
        target = tr[target].s[tr[target].v < x];
    }
    splay(target, 0, tr_id);
    const int pred = get_l(target);
    const int succ = get_r(target);

    splay(pred, 0, tr_id);
    splay(succ, pred, tr_id);
    tr[succ].s[0] = 0;      // detach the target node
    push_up(succ);
    push_up(pred);

    insert(v, tr_id);
}


int get_pre(int v,int tr_id){
    int u = root[tr_id],ans = -INF;
    while(u){
        if(tr[u].v < v){
            ans = max(ans, tr[u].v);
            u = tr[u].s[1];
        }else{
            u = tr[u].s[0];
        }
    }
    return ans;
}

int get_suc(int v,int tr_id){
    int u = root[tr_id], ans = INF;
    while(u){
        if(tr[u].v > v){
            ans = min(ans, tr[u].v);
            u = tr[u].s[0];
        }else{
            u = tr[u].s[1];
        }
    }
    return ans;
}


// ---- outer segment tree: node rt covers positions [l, r], children rt<<1 / rt<<1|1 ----

// Build segment-tree node rt over positions [l, r]. Each node gets its own
// splay holding the two sentinels plus arr[l..r]; then recurse into the
// two halves.
void build(int l,int r,int rt){
    insert(-INF, rt);
    insert(INF, rt);
    for(int pos = l; pos <= r; ++pos) insert(arr[pos], rt);

    if(l == r) return;
    const int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
}


// Count of values strictly less than x among positions [ql, qr].
// (l, r, rt) describe the current segment-tree node.
int query_less(int l,int r,int rt,int ql,int qr,int x){
    if(ql <= l && r <= qr) return count_less(x, rt);

    const int mid = (l + r) >> 1;
    int total = 0;
    if(ql <= mid) total += query_less(l, mid, rt << 1, ql, qr, x);
    if(mid < qr) total += query_less(mid + 1, r, rt << 1 | 1, ql, qr, x);
    return total;
}

// k-th smallest value among positions [ql, qr], found by binary search.
// The answer is the largest candidate in [0, 1e8] that has fewer than k
// values strictly below it. Returns -1 if no candidate qualifies.
int rank_k(int ql,int qr,int k){
    int lo = 0, hi = 1e8, best = -1;
    while(lo <= hi){
        const int mid = (lo + hi) >> 1;
        if(query_less(1, n, 1, ql, qr, mid) < k){
            best = mid;
            lo = mid + 1;
        }else{
            hi = mid - 1;
        }
    }
    return best;
}


// Set position pos to v. Every segment-tree node on the root-to-leaf path
// swaps the old value for v in its splay. arr[pos] is overwritten only at
// the leaf, so all ancestors still see the old value when they call update.
void modify(int l,int r,int rt,int pos,int v){
    update(arr[pos], v, rt);
    if(l == r){
        arr[pos] = v;
        return;
    }
    const int mid = (l + r) >> 1;
    if(pos <= mid) modify(l, mid, rt << 1, pos, v);
    else modify(mid + 1, r, rt << 1 | 1, pos, v);
}

// Largest value strictly less than v among positions [ql, qr]; -INF if none.
int query_pre(int l,int r,int rt,int ql,int qr,int v){
    if(ql <= l && r <= qr) return get_pre(v, rt);

    const int mid = (l + r) >> 1;
    const int from_left = (ql <= mid) ? query_pre(l, mid, rt << 1, ql, qr, v) : -INF;
    const int from_right = (qr > mid) ? query_pre(mid + 1, r, rt << 1 | 1, ql, qr, v) : -INF;
    return max(from_left, from_right);
}

// Smallest value strictly greater than v among positions [ql, qr]; INF if none.
int query_suc(int l,int r,int rt,int ql,int qr,int v){
    if(ql <= l && r <= qr) return get_suc(v, rt);

    const int mid = (l + r) >> 1;
    const int from_left = (ql <= mid) ? query_suc(l, mid, rt << 1, ql, qr, v) : INF;
    const int from_right = (qr > mid) ? query_suc(mid + 1, r, rt << 1 | 1, ql, qr, v) : INF;
    return min(from_left, from_right);
}


// Read n values, build the structure, then process m operations:
//   1 L R X : rank of X in [L, R]      2 L R X : X-th smallest in [L, R]
//   3 L X   : arr[L] = X               4 L R X : predecessor of X in [L, R]
//   any other op is treated as 5 L R X : successor of X in [L, R]
int main(){
    scanf("%d%d",&n,&m);
    for(int i = 1; i <= n; ++i) scanf("%d",&arr[i]);
    build(1, n, 1);

    while(m--){
        scanf("%d",&op);
        switch(op){
        case 1:
            scanf("%d%d%d",&L,&R,&X);
            printf("%d\n",query_less(1, n, 1, L, R, X)+1);
            break;
        case 2:
            scanf("%d%d%d",&L,&R,&X);
            printf("%d\n",rank_k(L, R, X));
            break;
        case 3:
            scanf("%d%d",&L,&X);
            modify(1, n, 1, L, X);
            break;
        case 4:
            scanf("%d%d%d",&L,&R,&X);
            printf("%d\n",query_pre(1, n, 1, L, R, X));
            break;
        default:
            scanf("%d%d%d",&L,&R,&X);
            printf("%d\n",query_suc(1, n, 1, L, R, X));
            break;
        }
    }

    return 0;
}

posted @ 2021-02-16 19:40  popozyl  阅读(43)  评论(0)  编辑  收藏  举报