Splay+离散化 - HDU 3436 - Queue-jumpers

Splay+离散化 - HDU 3436 - Queue-jumpers

因为太菜怕离散化写错,一开始尝试着写在线算法竟然还过了(玄学复杂度不会证明,貌似破坏了splay的期望logN)。本文先介绍离线的正解,文末附在线算法的代码。

1. 离线算法

离散化其实挺好想,数据范围N=1e8肯定不可能开一个1e8的splay。离散化之后就是常规的splay处理。这题多了一个少见的top操作,实现方法就是splay(L, 0) splay(R ,L)然后在 L 和 R之间插入新节点。

这里贴一个他人的代码,代码来源:Przz

/*
hdu 3436 splay树+离散化*
本来以为很好做的,写到中途发现10^8,GG
然后参考了下,把操作不用的区间缩点离散化处理
然后就是删除点,感觉自己开始写的太麻烦了,将要删除的点移动到根,如果没有儿子直接删掉,
否则将右树的最小点移到ch[r][1]使右树没有左子树,然后把根的左树接到右树上
hhh-2016-02-20 22:22:22
*/
 
#include <functional>
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <map>
#include <cmath>
using namespace std;
typedef long long ll;
typedef long  double ld;
#define key_value ch[ch[root][1]][0]
const int maxn = 200010;
 
int ch[maxn][2];
int pre[maxn],key[maxn],siz[maxn],num[maxn];
 
int root,tot,cnt,n,TOT;
int posi[maxn];
char qry[maxn][10];
int op[maxn];
int te[maxn];
int s[maxn],e[maxn];
 
void Treaval(int x) {
    if(x) {
        Treaval(ch[x][0]);
        printf("结点%2d:左儿子 %2d 右儿子 %2d 父结点 %2d size = %2d ,key = %2d   num= %2d \n",x,ch[x][0],ch[x][1],pre[x],siz[x],key[x],num[x]);
        Treaval(ch[x][1]);
    }
}
void debug() {printf("%d\n",root);Treaval(root);}
 
void push_up(int r)
{
    int lson = ch[r][0],rson = ch[r][1];
    siz[r] = siz[lson] + siz[rson] + num[r];
}
 
void push_down(int r)
{
 
}
 
void inOrder(int r)
{
    if(!r)return;
    inOrder(ch[r][0]);
    printf("%d ",key[r]);
    inOrder(ch[r][1]);
}
 
 
void NewNode(int &r,int far,int k)
{
    r = ++tot;
    posi[k] = r;
    key[r] = k;
    pre[r] = far;
    ch[r][0] = ch[r][1] = 0;
    siz[r] = num[r] = e[k]-s[k]+1;
}
 
 
void rotat(int x,int kind)
{
    int y = pre[x];
    push_down(y);
    push_down(x);
    ch[y][!kind] = ch[x][kind];
    pre[ch[x][kind]] = y;
    if(pre[y])
        ch[pre[y]][ch[pre[y]][1]==y] = x;
    pre[x] = pre[y];
    ch[x][kind] = y;
    pre[y] = x;
    push_up(y);
}
 
void build(int &x,int l,int r,int far)
{
    if(l > r) return ;
    int mid = (l+r) >>1;
    NewNode(x,far,mid);
    build(ch[x][0],l,mid-1,x);
    build(ch[x][1],mid+1,r,x);
    push_up(x);
}
 
void splay(int r,int goal)
{
    push_down(r);
    while(pre[r] != goal)
    {
        if(pre[pre[r]] == goal)
        {
            push_down(pre[r]);
            push_down(r);
            rotat(r,ch[pre[r]][0] == r);
        }
        else
        {
            push_down(pre[pre[r]]);
            push_down(pre[r]);
            push_down(r);
            int y = pre[r];
            int kind = ch[pre[y]][0] == y;
            if(ch[y][kind] == r)
            {
                rotat(r,!kind);
                rotat(r,kind);
            }
            else
            {
                rotat(y,kind);
                rotat(r,kind);
            }
        }
    }
    push_up(r);
    if(goal == 0)
        root = r;
}
 
int Bin(int x)
{
    int l = 0,r = TOT-1;
    while(l<=r)
    {
        int mid=(l+r)>>1;
        if(s[mid]<=x&&e[mid]>=x)
            return mid;
        if(e[mid]<x)
            l=mid+1;
        else
            r=mid-1;
    }
}
 
int get_min(int r)
{
    push_down(r);
    while(ch[r][0])
    {
        r = ch[r][0];
        push_down(r);
    }
    return r;
}
 
int get_kth(int r,int k)
{
    int t = siz[ch[r][0]];
    if(k<=t)
        return get_kth(ch[r][0],k);
    else if(k<=t+num[r])
        return s[key[r]]+(k-t)-1;
    else
        return get_kth(ch[r][1],k-t-num[r]);
}
 
void delet()
{
    if(ch[root][0] == 0 || ch[root][1] == 0)
    {
        root = ch[root][0] + ch[root][1];
        pre[root] = 0;
        return;
    }
    int k = get_min(ch[root][1]);
    splay(k,root);
    ch[ch[root][1]][0] = ch[root][0];
    root = ch[root][1];
    pre[ch[root][0]] = root;
    pre[root] = 0;
    push_up(root);
}
 
 
int top(int t)
{
    int r = Bin(t);
    r = posi[r];
    splay(r,0);
    delet();
    splay(get_min(root),0);
    ch[r][0] = 0;
    ch[r][1] = root;
    pre[root] = r;
    root = r;
    pre[root] = 0;
    push_up(root);
//    debug();
}
 
int Query(int x)
{
    int r = Bin(x);
    r = posi[r];
    splay(r,0);
    return siz[ch[r][0]]+1;
}
 
int get_rank(int x,int k)
{
    int t = siz[ch[x][0]];
    if(k <= t)
        return get_rank(ch[x][0],k);
    else
        return get_rank(ch[x][1],k-t);
}
 
 
void ini(int n)
{
    tot = root = 0;
    ch[root][0] = ch[root][1] = pre[root] = siz[root] = num[root] = 0 ;
    build(root,0,n-1,0);
 
    push_up(ch[root][1]);
    push_up(root);
    //inOrder(root);
}
 
 
int main()
{
    int q,T;
    int cas =1;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d%d",&n,&q) ;
        if(n == -1 && q == -1)
            break;
 
        int tcn = 0;
        printf("Case %d:\n",cas++);
        for(int i =1; i <= q; i++)
        {
            scanf("%s%d",qry[i],&op[i]);
            if(qry[i][0] == 'T' || qry[i][0] == 'Q')
                te[tcn++] = op[i];
        }
        te[tcn++] = n;
        te[tcn++] = 1;
        sort(te,te+tcn);
        TOT= 0;
        s[TOT] = te[0],e[TOT] = te[0],TOT++;
        for(int i = 1; i < tcn; i++)
        {
            if(te[i] != te[i-1] && i)
            {
                if(te[i] - te[i-1] > 1)
                {
                    s[TOT] = te[i-1]+1;
                    e[TOT] = te[i]-1;
                    TOT++;
                }
                s[TOT] = te[i];
                e[TOT] = te[i];
                TOT++;
            }
        }
        ini(TOT);
        //debug();
        for(int i = 1; i <= q; i++)
        {
            if(qry[i][0]=='T')
                top(op[i]);
            else if(qry[i][0]=='Q')
                printf("%d\n", Query(op[i]));
            else
                printf("%d\n",get_kth(root,op[i]));
        }
        //debug();
    }
    return 0;
}

2. 在线算法(不能证明复杂度)

2.1 节点定义

还是我们提到的问题,本题数据范围N=1e8,不可能为每个点单独开一个节点存。考虑用一个节点存一个线短。只有需要访问某一个线段的子线段时,才将这个线段拆分掉。

struct Node{
    int l, r; // 所代表的区间范围
    int p, size, s[2];
    
    void clear(){
        l = r = p = size = s[0] = s[1] = 0;
    }
    
    void init(int _l, int _r, int _p){
        s[0] = s[1] = 0;
        l = _l;
        r = _r;
        p = _p;
        size = _r-_l+1;
    }
}tr[M];

我们将线段的左右关系当作Splay插入时的大小关系

int insert(int l, int r){ 
    int u = rt, p = 0;
    while(u){
        p = u;
        if(l > tr[u].r){
            u = tr[u].s[1];
        }else{
            u = tr[u].s[0];
        }
    }

    u = ++idx;
    tr[u].init(l, r, p);
    if(p){
        if(l > tr[p].r){
            tr[p].s[1] = u;
        }else{
            tr[p].s[0] = u;
        }
    }
    splay(u, 0);

    return u;
}

初始化时,插入左右哨兵和整个区间

int L = insert(-INF, -INF);
int R = insert(INF, INF);
int u = insert(1, n);

2.2 拆分区间

假设我现在需要进行top(x)操作,这个操作实际上需要划分出[x, x]这段区间。这个区间可能已经在之前的操作中被划分出来了,也可能没有被划分出来。

我们可以用一个map来维护目前已经划分出的区间

map<int, int> dict;
// key: 目前已经划分的各段区间的右端点
// value: 该区间的splay下标

接下来就可以进行区间的拆分了

void split_node(int pos, int l, int r){ 
  // 将某段区间[L,R] 分裂出 [l,r] 和 剩余部分
    splay(pos, 0); // 这样保证了该节点的size区间端点更新不会影响其祖先
    if(l == tr[pos].l){
        int tmp = tr[pos].r;
        tr[pos].r = r;
        push_up(pos);
        
        // 插入 [r+1, R]
        int u = add_suc(r+1, tmp, pos);
        // [L,R] -> [L,r] [r+1,R]
        dict.find(tmp)->second = u; // R 存到u中
        dict.insert(make_pair(r, pos)); // r 存到pos中
    }else if(r == tr[pos].r){
        int tmp = tr[pos].l;
        tr[pos].l = l;
        push_up(pos);
        
        int u = add_pre(tmp, l-1, pos);
        // [L,R] -> [L,l-1] [l,R]
        dict.insert(make_pair(l-1, u));
    }else{
        // [L,R] -> [L, l-1], [l,r] , [r+1, R]
        int tmpl = tr[pos].l, tmpr = tr[pos].r;
        tr[pos].l = l;
        tr[pos].r = r;
        push_up(pos);
        
        int u = add_pre(tmpl, l-1, pos);
        int v = add_suc(r+1, tmpr, pos);

        dict.find(tmpr)->second = v;
        dict.insert(make_pair(l-1, u));
        dict.insert(make_pair(r, pos));
    }
    
}

其中,add_pre为将一个新节点插入到另一个节点的前驱,add_suc为将一个新节点插入到另一个节点的后继。这个操作不是常规的splay插入操作,因此无法保证splay仍能保持期望log(N)的复杂度(所以我说这个做法比较玄学)。

void push_up(int x){
    tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + tr[x].r - tr[x].l + 1;
}

void push_up_to_root(int x){
    while(x){
        push_up(x);
        x = tr[x].p;
    }
}

int add_pre(int l, int r, int p){
    int u = ++idx;
    tr[u].init(l, r, p);
    tr[u].s[0] = tr[p].s[0];
    tr[tr[p].s[0]].p = u;
    tr[p].s[0] = u;
    push_up_to_root(u);
    splay(u, 0);
    return u;
}

int add_suc(int l, int r, int p){
    int u = ++idx;
    tr[u].init(l, r, p);
    tr[u].s[1] = tr[p].s[1];
    tr[tr[p].s[1]].p = u;
    tr[p].s[1] = u;
    push_up_to_root(u);
    splay(u, 0);
    return u;
}

2.3 top(x)

top操作需要将区间[x, x]单独划分出来,把原本的[x, x]删掉,然后将新的[x, x]插入到[-INF, -INF]的后继

那么如何判断x是否被划分出来了呢?

void top(int x){
    map<int, int> :: iterator it_x = dict.find(x);
    if(it_x == dict.end()){
        map<int, int> :: iterator it_lb = dict.lower_bound(x);
        split_node(it_lb->second, x, x);
        it_x = dict.find(x);
    }else if(tr[it_x->second].l != tr[it_x->second].r){
        split_node(it_x->second, x, x);
        it_x = dict.find(x);
    }
    
    int u = it_x->second;
    
    int xl = get_pre(u);
    int xr = get_suc(u);
    splay(xl, 0);
    splay(xr, xl);
    tr[xr].s[0] = 0;
    
    push_up(xr);
    push_up(xl);
    
    tr[u].l = -123456;
    // 更新 x 的位置
    int v = insert(-INF+1,-INF+1);
    splay(L, 0);
    splay(v, L);
    splay(R, v);
    tr[v].l = tr[v].r = x;
    it_x->second = v;
}

2.4 query(x)

这个操作很容易实现,只需要先用map查询节点的splay下标,然后将其转至根,统计左儿子大小即可。注意哨兵[-INF, -INF]会占据一位。

int query(int x){
    map<int, int> :: iterator it_x = dict.find(x);
    if(it_x == dict.end()){
        map<int, int> :: iterator it_lb = dict.lower_bound(x);
        split_node(it_lb->second, x, x);
        it_x = dict.find(x);
    }else if(tr[it_x->second].l != tr[it_x->second].r){
        split_node(it_x->second, x, x);
        it_x = dict.find(x);
    }
    
    int u = it_x->second;
    splay(u, 0);
    return tr[tr[u].s[0]].size;
}

2.5 rank_x(x)

这个操作也很简单,与常规的getk(k)稍微有一点区别。

int rank_x(int x){
    ++x;
    int u = rt;
    while(u){
        if(tr[tr[u].s[0]].size >= x){
            u = tr[u].s[0];
        }else if(tr[tr[u].s[0]].size + tr[u].r - tr[u].l + 1 >= x){
            x -= tr[tr[u].s[0]].size;
            return tr[u].l + x - 1;
        }else{
            x -= tr[tr[u].s[0]].size + tr[u].r - tr[u].l + 1;
            u = tr[u].s[1];
        }
    }
    return -1;
}

2.6 完整代码

#include <bits/stdc++.h>
using namespace std;

const int M = 10+1;
const int INF = 0x3fffffff;

int T;
int L, R;
int n, m, num;
char op[10];

map<int,int> dict;

// [L,R]段所存的下标(这个L不需要存储)
// 当前所有存在的区间中,右端点>=L的最小值

int rt, idx;

struct Node{
    int l, r; // 所代表的区间范围
    int p, size, s[2];
    
    void clear(){
        l = r = p = size = s[0] = s[1] = 0;
    }
    
    void init(int _l, int _r, int _p){
        s[0] = s[1] = 0;
        l = _l;
        r = _r;
        p = _p;
        size = _r-_l+1;
    }
}tr[M];

int ws(int x){
    return tr[tr[x].p].s[1] == x;
}

void push_up(int x){
    tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + tr[x].r - tr[x].l + 1;
}

void push_up_to_root(int x){
    while(x){
        push_up(x);
        x = tr[x].p;
    }
}

void rotate(int x){
    int y = tr[x].p;
    int z = tr[y].p;
    int k = ws(x);
    tr[z].s[ws(y)] = x;
    tr[y].p = x;
    tr[y].s[k] = tr[x].s[k^1];
    tr[tr[x].s[k^1]].p = y;
    tr[x].p = z;
    tr[x].s[k^1] = y;
    push_up(y);
    push_up(x);
}

void splay(int x, int k){
    while(tr[x].p != k){
        int y = tr[x].p;
        int z = tr[y].p;
        if(z != k){
            if(ws(x) ^ ws(y)){
                rotate(x);
            }else{
                rotate(y);
            }
        }
        rotate(x);
    }
    if(!k) rt = x;
}

int insert(int l, int r){ // 我们默认,插入节点的时候已经将原区间划分开了
    int u = rt, p = 0;
    while(u){
        p = u;
        if(l > tr[u].r){
            u = tr[u].s[1];
        }else{
            u = tr[u].s[0];
        }
    }

    u = ++idx;
    tr[u].init(l, r, p);
    if(p){
        if(l > tr[p].r){
            tr[p].s[1] = u;
        }else{
            tr[p].s[0] = u;
        }
    }
    splay(u, 0);

    return u;
}

int add_pre(int l, int r, int p){
    int u = ++idx;
    tr[u].init(l, r, p);
    tr[u].s[0] = tr[p].s[0];
    tr[tr[p].s[0]].p = u;
    tr[p].s[0] = u;
    push_up_to_root(u);
    splay(u, 0);
    return u;
}

int add_suc(int l, int r, int p){
    int u = ++idx;
    tr[u].init(l, r, p);
    tr[u].s[1] = tr[p].s[1];
    tr[tr[p].s[1]].p = u;
    tr[p].s[1] = u;
    push_up_to_root(u);
    splay(u, 0);
    return u;
}

// ERROR:直接插入会改变某些顺序
void split_node(int pos, int l, int r){ // 将某段区间[L,R] 分裂出 [l,r] 和 剩余部分
    splay(pos, 0); // 这样保证了该节点的size区间端点更新不会影响其祖先
    if(l == tr[pos].l){
        int tmp = tr[pos].r;
        tr[pos].r = r;
        push_up(pos);
        
        // 插入 [r+1, R]
        int u = add_suc(r+1, tmp, pos);
        // [L,R] -> [L,r] [r+1,R]
        dict.find(tmp)->second = u; // R 存到u中了
        dict.insert(make_pair(r, pos)); // r 存到pos中了
    }else if(r == tr[pos].r){
        int tmp = tr[pos].l;
        tr[pos].l = l;
        push_up(pos);
        
        int u = add_pre(tmp, l-1, pos);
        // [L,R] -> [L,l-1] [l,R]
        dict.insert(make_pair(l-1, u));
    }else{
        // [L,R] -> [L, l-1], [l,r] , [r+1, R]
        int tmpl = tr[pos].l, tmpr = tr[pos].r;
        tr[pos].l = l;
        tr[pos].r = r;
        push_up(pos);
        
        int u = add_pre(tmpl, l-1, pos);
        int v = add_suc(r+1, tmpr, pos);

        dict.find(tmpr)->second = v;
        dict.insert(make_pair(l-1, u));
        dict.insert(make_pair(r, pos));
    }
    
}

int get_pre(int x){
    splay(x, 0);
    int u = tr[x].s[0];
    while(tr[u].s[1]) u = tr[u].s[1];
    return u;
}

int get_suc(int x){
    splay(x, 0);
    int u = tr[x].s[1];
    while(tr[u].s[0]) u = tr[u].s[0];
    return u;
}

void top(int x){
    map<int, int> :: iterator it_x = dict.find(x);
    if(it_x == dict.end()){
        map<int, int> :: iterator it_lb = dict.lower_bound(x);
        split_node(it_lb->second, x, x);
        it_x = dict.find(x);
    }else if(tr[it_x->second].l != tr[it_x->second].r){
        split_node(it_x->second, x, x);
        it_x = dict.find(x);
    }
    
    int u = it_x->second;
    
    int xl = get_pre(u);
    int xr = get_suc(u);
    splay(xl, 0);
    splay(xr, xl);
    tr[xr].s[0] = 0;
    
    push_up(xr);
    push_up(xl);
    
    tr[u].l = -123456;
    // 更新 x 的位置
    int v = insert(-INF+1,-INF+1);
    splay(L, 0);
    splay(v, L);
    splay(R, v);
    tr[v].l = tr[v].r = x;
    it_x->second = v;
}


int query(int x){
    map<int, int> :: iterator it_x = dict.find(x);
    if(it_x == dict.end()){
        map<int, int> :: iterator it_lb = dict.lower_bound(x);
        split_node(it_lb->second, x, x);
        it_x = dict.find(x);
    }else if(tr[it_x->second].l != tr[it_x->second].r){
        split_node(it_x->second, x, x);
        it_x = dict.find(x);
    }
    
    int u = it_x->second;
    splay(u, 0);
    return tr[tr[u].s[0]].size;
}

int rank_x(int x){
    ++x;
    int u = rt;
    while(u){
        if(tr[tr[u].s[0]].size >= x){
            u = tr[u].s[0];
        }else if(tr[tr[u].s[0]].size + tr[u].r - tr[u].l + 1 >= x){
            x -= tr[tr[u].s[0]].size;
            return tr[u].l + x - 1;
        }else{
            x -= tr[tr[u].s[0]].size + tr[u].r - tr[u].l + 1;
            u = tr[u].s[1];
        }
    }
    return -1;
}


void init(){
    dict.clear();
    for(int i = 1; i <= idx; ++i){
        tr[i].clear();
    }
    rt = 0;
    idx = 0;
}


int main(){
    scanf("%d", &T);
    for(int t = 1; t <= T; ++t){
        printf("Case %d:\n", t);
        
        init();
        
        L = insert(-INF,-INF);
        R = insert(INF, INF);
        scanf("%d%d", &n, &m);
        
        int base = insert(1, n);
        dict.insert(make_pair(n, base));
        while(m--){
            scanf("%s%d", op, &num);
            if(*op == 'T'){
                top(num);
            }else if(*op == 'R'){
                printf("%d\n", rank_x(num));
            }else{
                printf("%d\n", query(num));
            }
        }
    }
    
    
    
    return 0;
}
posted @ 2021-03-12 09:15  popozyl  阅读(119)  评论(1编辑  收藏  举报