平衡树学习笔记

非旋平衡树 FHQ-Treap

这里介绍的是非旋 \(Treap\),即 \(FHQ-Treap\),毕竟这个好写太多,而且支持各种操作。

\(FHQ-Treap\) 包含两个重要操作:分裂和合并。

分裂(split)

分裂指的是将一棵以 \(root\) 为根节点的树,分裂成两棵分别以 \(a,b\) 为根节点的树。其中有两种分裂方式,第一种是按照权值分裂,第二种是按照子树大小分裂。

以下是一份按照权值分裂,将根为 \(nd\) 的树分裂成根分别为 \(x,y\) 的子树,并且满足 \(x\) 的权值小于等于 \(k\)\(y\) 的权值大于 \(k\)

void split(int nd,int k,int &x,int &y) {
    if(!nd) return void(x=y=0);
    if(tr[nd].val<=k) { //当前节点<=k,放进 x 里面
        x=nd;
        split(tr[x].r,k,tr[x].r,y);//构造x的右子树
    }
    else {//同上
        y=nd;
        split(tr[y].l,k,x,tr[y].l);
    }
    pushup(nd);//记得最后用nd的左右儿子更新nd
}

以下是一份按照子树大小分裂,将根为 \(nd\) 的分裂成根分别为 \(x,y\) 的子树,且 \(x\) 是前 \(k\) 个树,\(y\) 是后面的树,这种分裂方式通常适用于序列操作。

void split(int nd,int k,int &x,int &y) {
    if(!nd) return void(x=y=0);
    pushdown(nd);
    if(tr[tr[nd].l].siz>=k) {
        y=nd;
        split(tr[y].l,k,x,tr[y].l);
    }
    else {
        x=nd;
        split(tr[x].r,k-tr[tr[nd].l].siz-1,tr[x].r,y);
    }
    pushup(nd);
}

合并

合并就是将两个分别以 \(a,b\) 为根的子树,合并成一棵树。注意,\(a,b\) 的先后顺序是不能调换的。由于 \(FHQ-Treap\) 只是不旋转的 \(Treap\),但还是给每个节点赋一个键值(随机值),在合并时通过键值的大小关系判断是将谁合并到谁的下面。

int merge(int x,int y) {
    if(!x||!y) return x+y;
    if(tr[x].key<=tr[y].key) {
        tr[x].r=merge(tr[x].r,y);
        pushup(x);
        return x;
    }
    else {
        tr[y].l=merge(x,tr[y].l);
        pushup(y);
        return y;
    }
}

有了上述两个操作,我们就可以实现几乎全部的平衡树功能。

插入

void insert(int val) {
    int a=0,b=0;
    split(root,val,a,b);
    root=merge(merge(a,NewNode(val)),b);
}

删除

void remove(int val) {
    int a,b,c,d;
    split(root,val,a,b);
    split(a,val-1,c,d);
    d=merge(tr[d].l,tr[d].r);
    root=merge(merge(c,d),b);
}

查找排名

int FindRank(int val) {
    int a,b;
    split(root,val-1,a,b);
    int ans=tr[a].siz+1;
    root=merge(a,b);
    return ans;
}

按照排名找值

int FindVal(int nd,int k) {
    if(tr[tr[nd].l].siz==k-1) return tr[nd].val;
    else if(k<=tr[tr[nd].l].siz) return FindVal(tr[nd].l,k);
    else return FindVal(tr[nd].r,k-tr[tr[nd].l].siz-1);
}

前驱,后继

int FindPre(int val) {
    int a,b;
    split(root,val-1,a,b);
    int ans=FindVal(a,tr[a].siz);
    root=merge(a,b);
    return ans;
}

int FindNxt(int val) {
    int a,b;
    split(root,val,a,b);
    int ans=FindVal(b,1);
    root=merge(a,b);
    return ans;
}

以上都是权值平衡树的基本操作,下面是一些题目:

【模板】普通平衡树

板子题。不解释。

点击查看代码
#include<bits/stdc++.h>
using namespace std;

const int N=1e5+10;
int T;
int tot,root;
struct node {
    int l,r;
    int val,siz,key;
}tr[N];

void pushup(int nd) {
    tr[nd].siz=tr[tr[nd].l].siz+tr[tr[nd].r].siz+1;
}

int NewNode(int val) {
    tr[++tot].val=val;
    tr[tot].siz=1;
    tr[tot].key=rand();
    return tot;
}

void split(int nd,int k,int &x,int &y) {
    if(!nd) return void(x=y=0);
    if(tr[nd].val<=k) {
        x=nd;
        split(tr[x].r,k,tr[x].r,y);
    }
    else {
        y=nd;
        split(tr[y].l,k,x,tr[y].l);
    }
    pushup(nd);
}

int merge(int x,int y) {
    if(!x||!y) return x+y;
    if(tr[x].key<tr[y].key) {
        tr[x].r=merge(tr[x].r,y);
        pushup(x);
        return x;
    }
    else {
        tr[y].l=merge(x,tr[y].l);
        pushup(y);
        return y;
    }
}

void insert(int val) {
    int a=0,b=0;
    split(root,val,a,b);
    root=merge(merge(a,NewNode(val)),b);
}

void remove(int val) {
    int a,b,c,d;
    split(root,val,a,b);
    split(a,val-1,c,d);
    d=merge(tr[d].l,tr[d].r);
    root=merge(merge(c,d),b);
}

int FindRank(int val) {
    int a,b;
    split(root,val-1,a,b);
    int ans=tr[a].siz+1;
    root=merge(a,b);
    return ans;
}

int FindVal(int nd,int k) {
    if(tr[tr[nd].l].siz==k-1) return tr[nd].val;
    else if(k<=tr[tr[nd].l].siz) return FindVal(tr[nd].l,k);
    else return FindVal(tr[nd].r,k-tr[tr[nd].l].siz-1);
}

int FindPre(int val) {
    int a,b;
    split(root,val-1,a,b);
    int ans=FindVal(a,tr[a].siz);
    root=merge(a,b);
    return ans;
}

int FindNxt(int val) {
    int a,b;
    split(root,val,a,b);
    int ans=FindVal(b,1);
    root=merge(a,b);
    return ans;
}

int main() {
    cin>>T;
    while(T--) {
        int opt,x; cin>>opt>>x;
        if(opt==1) insert(x);
        if(opt==2) remove(x);
        if(opt==3) cout<<FindRank(x)<<endl;
        if(opt==4) cout<<FindVal(root,x)<<endl;
        if(opt==5) cout<<FindPre(x)<<endl;
        if(opt==6) cout<<FindNxt(x)<<endl;
    }
    return 0;
}

[TJOI2019] 甲苯先生的滚榜

这也是一道板子题。只要将权值用类写一下就好了。

点击查看代码
#include<bits/stdc++.h>
using namespace std;

typedef unsigned int ui ;
ui randNum( ui& seed , ui last , const ui m){ 
    seed = seed * 17 + last ; return seed % m + 1;
}

const int N=1e6+10,M=1e5+10;
int lst=7;
int T,n,m; ui seed;
int tot,root;
struct node_val {
    int num,tm;
    friend bool operator < (node_val x,node_val y) {
        if(x.num!=y.num) return x.num<y.num;
        else return x.tm>y.tm;
    }
    friend bool operator == (node_val x,node_val y) {
        return (x.num==y.num&&x.tm==y.tm);
    }
}a[M];
struct node {
    int l,r;
    int siz,key;
    node_val val;
}tr[N+M];

int NewNode(node_val x) {
    tr[++tot].key=rand();
    tr[tot].siz=1;
    tr[tot].val.num=x.num;
    tr[tot].val.tm=x.tm;
    return tot;
}

void pushup(int nd) {
    tr[nd].siz=tr[tr[nd].l].siz+tr[tr[nd].r].siz+1;
}

void split(int nd,node_val k,int &x,int &y) {
    if(!nd) return void(x=y=0);
    if(tr[nd].val<k||k==tr[nd].val) {
        x=nd;
        split(tr[x].r,k,tr[x].r,y);
    }
    else {
        y=nd;
        split(tr[y].l,k,x,tr[y].l);
    }
    pushup(nd);
}

int merge(int x,int y) {
    if(!x||!y) return x+y;
    if(tr[x].key>=tr[y].key) {
        tr[x].r=merge(tr[x].r,y);
        pushup(x);
        return x;
    }
    else {
        tr[y].l=merge(x,tr[y].l);
        pushup(y);
        return y;
    }
}

void Insert(node_val x) {
    int a,b;
    split(root,x,a,b);
    root=merge(merge(a,NewNode(x)),b);
}

void Remove(node_val x) {
    int a,b,c,d;
    split(root,x,a,b);
    x.tm++;
    split(a,x,c,d);
    d=merge(tr[d].l,tr[d].r);
    root=merge(merge(c,d),b);
}

int FindRank(node_val x) {
    int a,b;
    split(root,x,a,b);
    int ans=tr[b].siz;
    root=merge(a,b);
    return ans;
}

void Solve() {
    cin>>m>>n>>seed;
    for(int i=1;i<=m;i++) Insert(a[i]);
    for(int i=1;i<=n;i++) {
        int ria=randNum(seed,lst,m);
        int rib=randNum(seed,lst,m);
        Remove(a[ria]);
        a[ria].num++; a[ria].tm+=rib;
        Insert(a[ria]);
        lst=FindRank(a[ria]);
        cout<<lst<<"\n";
    }
}

void Clear() {
    for(int i=0;i<=tot;i++) tr[i].key=tr[i].l=tr[i].r=tr[i].siz=tr[i].val.num=tr[i].val.tm=0;
    for(int i=0;i<=m;i++) a[i].num=a[i].tm=0;
    n=m=tot=root=0;
}

int main() {
    cin>>T;
    while(T--) {
        Solve();
        Clear();
    }

    return 0;
}

[NOI2004] 郁闷的出纳员

虽然评分是蓝,但是没想到确实不太好做。

首先删除权值小于 \(\min\) 的是很容易实现的,只需要分裂一下就行,那么我们就要以权值为关键字,接下来考虑如何模拟这个过程。

将每位员工的工资加上 \(k\),显然是没有办法直接实现的,因为关键字是权值,而不是位置,所以我们不能在平衡树上进行区间操作,我们考虑记录一个增量 \(\Delta x\),那么每个员工的实际工资是 \(k+\Delta x\),但是这样还是存在问题的,如果此时有一个新员工加入,那么这个员工实际上并没有经历工资增加,所以我们在其加入的时候,将 \(k\) 减去 \(\Delta x\),这样这个员工的实际工资就是 \(k-\Delta x+\Delta x=k\)。符合条件。

点击查看代码
#include<bits/stdc++.h>
using namespace std;

const int N=1e5+10;
int n,minn,b,ans;
int tot,root;
struct node {
    int l,r;
    int siz,val,key;
}tr[N];

void pushup(int nd) {
    tr[nd].siz=tr[tr[nd].l].siz+tr[tr[nd].r].siz+1;
}

int NewNode(int val) {
    tr[++tot].siz=1;
    tr[tot].val=val;
    tr[tot].key=rand();
    return tot;
}

void split(int nd,int k,int &x,int &y) {
    if(!nd) return void(x=y=0);
    if(tr[nd].val<=k) {
        x=nd;
        split(tr[x].r,k,tr[x].r,y);
    }
    else {
        y=nd;
        split(tr[y].l,k,x,tr[y].l);
    }
    pushup(nd);
}

int merge(int x,int y) {
    if(!x||!y) return x+y;
    if(tr[x].key<=tr[y].key) {
        tr[x].r=merge(tr[x].r,y);
        pushup(x);
        return x;
    }
    else {
        tr[y].l=merge(x,tr[y].l);
        pushup(y);
        return y;
    }
}

void Insert(int val) {
    int a,b;
    split(root,val,a,b);
    root=merge(merge(a,NewNode(val)),b);
}

void Remove(int val) {
    int a,b;
    split(root,val-1,a,b);
    ans+=tr[a].siz;
    root=b;
}

int FindVal(int nd,int k) {
    if(tr[tr[nd].l].siz+1==k) return tr[nd].val;
    else if(tr[tr[nd].l].siz>=k) return FindVal(tr[nd].l,k);
    else return FindVal(tr[nd].r,k-tr[tr[nd].l].siz-1);
}

int main() {
    cin>>n>>minn;
    while(n--) {
        string opt; int k;
        cin>>opt>>k;
        if(opt=="I") {
            if(k<minn) continue;
            else Insert(k-b);
        }
        else if(opt=="A") {
            b+=k;
        }
        else if(opt=="S") {
            b-=k;
            Remove(minn-b);
        }
        else {
            if(k>tr[root].siz) cout<<-1<<endl;
            else cout<<FindVal(root,tr[root].siz+1-k)+b<<endl;
        }
    }
    cout<<ans;

    return 0;
}

平衡树的区间操作

既然是区间操作,那么只能是以在区间中的位置为关键字,其实也不用将位置存进平衡树中,并且在分裂时多用按照子树大小分裂,我们只需要这样插入就可以保持原来的区间结构:

for(int i=1;i<=n;i++) {
    root=merge(root,NewNode(val));
    //val表示第i个位置上的相关信息
}

取出一段区间[l,r]

int a,b,c,d;
split(root,r,a,b);
split(a,l-1,c,d);
//d 树就是代表区间[l,r]的树

在某段区间干事情

这个和线段树其实是类似的,都需要 \(pushdown\) 操作,不过这里的 \(pushdown\) 需要记录更多细节。具体结合实际例子。

对某个编号为 \(k\) 的节点操作

由于我们是以在序列中的位置插进平衡树的,所以实际上没有办法直接对某个编号进行操作。

唯一的办法是:记录数组 \(pos[k]\) 表示编号为 \(k\) 的在平衡树中的节点编号。然后计算出这个节点在平衡树中的排名,由于 \(BST\) 用排名来当关键字,所以排名就是这个编号在序列中的实际位置,找出以后操作即可。

具体这样操作:首先该节点的左子树一定在其前面,如果该节点是其父亲的右儿子,那么父亲连同父亲的左儿子都比该节点小,计入答案,然后一路往上跳即可。由于期望树高是 \(\log n\) 的,所以时间复杂度约为 \(O(\log n)\).

int findpos(int nd) { //找到编号为nd的排名
    int ans=1+tr[tr[nd].l].siz,p=nd;
    while(p) {
        int fa=tr[p].fa;
        if(p==tr[fa].r) {//是父亲的右儿子,计入贡献
            ans+=tr[tr[fa].l].siz+1;
        }
        p=fa;
    }
    return ans;
}

以上就是平衡树的区间基本操作,来看几道例题吧:

【模板】文艺平衡树

感觉平衡树可以做线段树不能做的操作就只有区间反转了。。。

对每个节点打一个标记,表示这个节点是否需要翻转,翻转的本质就是将其每个儿子的子树交换即可。

最后输出时输出中序遍历就行。

点击查看代码
#include<bits/stdc++.h>
using namespace std;

const int N=110000;
int n,m;
int root,tot;
struct node {
    int l,r;
    int siz,val,key,tag;
}tr[N];

int NewNode(int val) {
    tr[++tot].key=rand();
    tr[tot].siz=1;
    tr[tot].val=val;
    return val;
}

void pushup(int nd) {
    tr[nd].siz=tr[tr[nd].l].siz+tr[tr[nd].r].siz+1;
}

void pushdown(int nd) {
    if(!tr[nd].tag) return;
    swap(tr[nd].l,tr[nd].r);
    tr[tr[nd].l].tag^=1; tr[tr[nd].r].tag^=1;
    tr[nd].tag=0;
}

int merge(int x,int y) {
    if(!x||!y) return x+y;
    pushdown(x); pushdown(y);
    if(tr[x].key<=tr[y].key) {
        tr[x].r=merge(tr[x].r,y);
        pushup(x);
        return x;
    }
    else {
        tr[y].l=merge(x,tr[y].l);
        pushup(y);
        return y;
    }
}

void split(int nd,int k,int &x,int &y) {
    if(!nd) return void(x=y=0);
    pushdown(nd);
    if(tr[tr[nd].l].siz>=k) {
        y=nd;
        split(tr[y].l,k,x,tr[y].l);
    }
    else {
        x=nd;
        split(tr[x].r,k-tr[tr[nd].l].siz-1,tr[x].r,y);
    }
    pushup(nd);
}

void res(int l,int r) {
    int a,b,c,d;
    split(root,l-1,a,b);
    split(b,r-l+1,c,d);
    tr[c].tag^=1;
    root=merge(a,merge(c,d));
}

void print(int nd) { //输出中序遍历
    if(!nd) return ;
    pushdown(nd);
    print(tr[nd].l);
    cout<<tr[nd].val<<' ';
    print(tr[nd].r);
}

int main() {
    cin>>n>>m;
    for(int i=1;i<=n;i++) {
        root=merge(root,NewNode(i));
    }
    for(int i=1;i<=m;i++) {
        int l,r;
        cin>>l>>r;
        res(l,r);
    }
    print(root);
    return 0;
}

序列终结者

在区间翻转的基础上加入了,区间加,区间求最大值的操作,事实上只需要注意 \(pushdown\) 函数的编写,并且注意不要将 \(0\) 号节点(不存在的节点)计入答案就行。

点击查看代码
#include<bits/stdc++.h>
using namespace std;

typedef long long LL;
const int N=5e4+10;
int n,m;
int tot,root;
struct node {
    int l,r;
    int siz,key,tag1;
    LL val,maxn,tag2;
}tr[N];

int NewNode(int x) {
    tr[++tot].key=rand();
    tr[tot].siz=1;
    tr[tot].maxn=x;
    tr[tot].val=x;
    return tot;
}

void pushup(int nd) {
    tr[nd].siz=tr[tr[nd].l].siz+tr[tr[nd].r].siz+1;
    tr[nd].maxn=tr[nd].val;
    if(tr[nd].l) tr[nd].maxn=max(tr[nd].maxn,tr[tr[nd].l].maxn);
    if(tr[nd].r) tr[nd].maxn=max(tr[nd].maxn,tr[tr[nd].r].maxn);
}

void Add(int nd,LL x) {
    tr[nd].maxn+=x;
    tr[nd].val+=x;
    tr[nd].tag2+=x;
}

void pushdown(int nd) {
    if(tr[nd].tag1) {
        swap(tr[nd].l,tr[nd].r);
        tr[tr[nd].l].tag1^=1; tr[tr[nd].r].tag1^=1;
        tr[nd].tag1=0;
    }
    if(tr[nd].tag2) {
        Add(tr[nd].l,tr[nd].tag2); Add(tr[nd].r,tr[nd].tag2);
        tr[nd].tag2=0;
    }
}

int merge(int x,int y) {
    if(!x||!y) return x+y;
    pushdown(x); pushdown(y);
    if(tr[x].key<=tr[y].key) {
        tr[x].r=merge(tr[x].r,y);
        pushup(x);
        return x;
    }
    else {
        tr[y].l=merge(x,tr[y].l);
        pushup(y);
        return y;
    }
}

void split(int nd,int k,int &x,int &y) {
    if(!nd) return void(x=y=0);
    pushdown(nd);
    if(k<=tr[tr[nd].l].siz) {
        y=nd;
        split(tr[y].l,k,x,tr[y].l);
    }
    else {
        x=nd;
        split(tr[x].r,k-tr[tr[nd].l].siz-1,tr[x].r,y);
    }
    pushup(nd);
}

void Res(int l,int r) {
    int a,b,c,d;
    split(root,r,a,b);
    split(a,l-1,c,d);
    tr[d].tag1^=1;
    a=merge(c,d);
    root=merge(a,b);
}

void Modify(int l,int r,LL x) {
    int a,b,c,d;
    split(root,r,a,b);
    split(a,l-1,c,d);
    Add(d,x);
    a=merge(c,d);
    root=merge(a,b);
}

LL ask(int l,int r) {
    int a,b,c,d;
    split(root,r,a,b);
    split(a,l-1,c,d);
    LL ans=tr[d].maxn;
    root=merge(merge(c,d),b);
    return ans;
}

void print(int nd) {
    if(!nd) return ;
    pushdown(nd);
    print(tr[nd].l);
    cout<<tr[nd].maxn<<' ';
    print(tr[nd].r);
}

int main() {
    srand(time(0));
    cin>>n>>m;
    for(int i=1;i<=n;i++) {
        root=merge(root,NewNode(0));
    }
    for(int i=1;i<=m;i++) {
        int opt,l,r; LL x;
        cin>>opt>>l>>r;
        if(opt==1) {
            cin>>x;
            Modify(l,r,x);
        }
        else if(opt==2) {
            Res(l,r);
        }
        else cout<<ask(l,r)<<"\n";
    }

    return 0;
}

[ZJOI2006] 书架

本题的难点就是如何找出编号为 \(s\) 的在哪个位置。几个修改操作均可用 \(split\)\(merge\) 搞定,若问第 \(s\) 本书的编号,其实就是问排名对应的值。

可以利用上面的代码求出某个编号对应的位置。

点击查看代码
#include<bits/stdc++.h>
using namespace std;

const int N=1e5;
int n,m;
int root,tot;
int pos[N];
struct node {
    int l,r;
    int siz,val,key,fa;
}tr[N];

int NewNode(int x) {
    tr[++tot].siz=1;
    tr[tot].key=rand();
    tr[tot].val=x;
    return tot;
}

void pushup(int x) {
    tr[x].siz=tr[tr[x].l].siz+tr[tr[x].r].siz+1;
    tr[tr[x].l].fa=x; tr[tr[x].r].fa=x;
}

void split(int nd,int k,int &x,int &y) {
    if(!nd) return void(x=y=0);
    if(tr[tr[nd].l].siz>=k) {
        y=nd;
        split(tr[y].l,k,x,tr[y].l);
    }
    else {
        x=nd;
        split(tr[x].r,k-tr[tr[nd].l].siz-1,tr[x].r,y);
    }
    pushup(nd);
}

int merge(int x,int y) {
    if(!x||!y) return x+y;
    if(tr[x].key>=tr[y].key) {
        tr[x].r=merge(tr[x].r,y);
        pushup(x);
        return x;
    }
    else {
        tr[y].l=merge(x,tr[y].l);
        pushup(y);
        return y;
    }
}

int findpos(int nd) { //找到编号为nd的排名
    int ans=1+tr[tr[nd].l].siz,p=nd;
    while(p) {
        pushup(p);
        int fa=tr[p].fa;
        if(p==tr[fa].r) {
            ans+=tr[tr[fa].l].siz+1;
        }
        p=fa;
    }
    return ans;
}

int findval(int nd,int k) {
    if(k==tr[tr[nd].l].siz+1) return tr[nd].val;
    else if(tr[tr[nd].l].siz>=k) return findval(tr[nd].l,k);
    else return findval(tr[nd].r,k-tr[tr[nd].l].siz-1);
}

int main() {
    cin>>n>>m;
    for(int i=1;i<=n;i++) {
        int x; cin>>x;
        pos[x]=NewNode(x);
        root=merge(root,pos[x]);
    }
    while(m--) {
        string opt; int s,t;
        int a,b,c,d;
        cin>>opt>>s;
        if(opt=="Top") {
            int id=findpos(pos[s]);
            split(root,id-1,a,b);
            split(b,1,c,d);
            // cout<<"["<<tr[c].val<<"]\n";
            root=merge(c,merge(a,d));
        }
        else if(opt=="Bottom") {
            int id=findpos(pos[s]);
            split(root,id-1,a,b);
            split(b,1,c,d);
            // cout<<"["<<id<<"]\n";
            root=merge(a,merge(d,c));
        }
        else if(opt=="Insert") {
            cin>>t;
            int id=findpos(pos[s]);
            if(t==0) continue;
            if(t==1) {//往后放一位
                split(root,id-1,a,b);
                split(b,2,c,d);
                int x,y; split(c,1,x,y);
                root=merge(a,merge(merge(y,x),d));
            }
            else {
                split(root,id-2,a,b);
                split(b,2,c,d);
                int x,y; split(c,1,x,y);
                root=merge(a,merge(merge(y,x),d));
            }
        }
        else if(opt=="Ask") {
            cout<<findpos(pos[s])-1<<endl;
        }
        else if(opt=="Query") {
            cout<<findval(root,s)<<endl;
        }
    }

    return 0;
}
posted @ 2023-08-23 23:04  2017BeiJiang  阅读(8)  评论(0编辑  收藏  举报