平衡树入门——替罪羊树
平衡树入门——替罪羊树
1 简介
替罪羊树是一颗重量平衡树,不需要旋转,但是非常暴力,据说常数很小,但是我写的替罪羊树跑不过 Treap ,可能常数比较大。。。
2 数据结构解析
2.1 节点结构体
struct node{
int val,l,r,cnt,size,allsize,not_dele_size;
};
node p[N];
\(val\) 是节点权值,\(l,r\) 是左右儿子,\(cnt\) 是权值相同节点个数,\(size\) 是子树大小,注意 \(size\) 并没有重复计数权值相同的节点,也就是子树节点个数,\(allsize\) 是算上了权值相同节点的大小。
那么 \(not\_dele\_size\) 是什么呢?请注意,替罪羊树的删除是惰性删除,也就是说这个节点就算已经被删除了它还是在那里,我们删除的时候只把 \(cnt-1\) ,而不管其他东西,所以 \(not\_dele\_size\) 值得就是不算被删除的节点,子树节点个数是多少。
换句话说,\(size\) 里面算上了被删除节点,而 \(allsize\) 因为统计的是 \(cnt\),被删除节点的 \(cnt\) 是 \(0\) ,所以 \(allsize\) 里面其实也没有计数被删除节点。
2.2 重构代码
没有旋转,替罪羊树是如何判断并维护整颗树是否平衡呢?我们考虑不平衡一定是一棵树有一颗子树很大,而另一颗子树很小,那么我们如何判断这件事情?我们引入平衡系数—— \(\alpha\),这个平衡系数选在 \(0.5\) 到 \(1\) 之间,通常取 \(0.7\) ,如果有一颗子树的大小占了整棵树的 \(\alpha\) 还多,我们就认为这颗树不平衡了。
如果使这棵树变得平衡呢?暴力重构整棵树就可以做。我们对这棵树进行中序排序,那么怎样建这颗树是最平衡的?我们取最中间的节点作为树根就是最平衡的,因为平衡树中序排序后的序列是有序的,所以这么做正确性显然。
在重构的时候我们顺便去掉所有已经被删除的节点。
首先是判断是否平衡的代码:
inline bool can_rest(int k){
return (p[k].cnt)&&(alpha*(dd)p[k].size<=max((dd)p[p[k].l].size,dd(p[p[k].r].size))||((dd)p[k].not_dele_size<=alpha*(dd)p[k].size));
}
请注意,因为替罪羊树是惰性删除,所以要时刻注意如何处理被删除节点,不能让被删除节点产生影响。除了判断子树大小,如果没有被删除的节点太多,也影响效率,我们也进行重构。
接下来是中序排序的代码和重构的代码:
inline void mid_travel(int &tail,int k){
if(!k) return;
mid_travel(tail,p[k].l);
if(p[k].cnt) mid_tra[tail++]=k;
mid_travel(tail,p[k].r);
}
inline int rest_build(int l,int r){
if(l>=r) return 0;
int mid=l+r>>1;
p[mid_tra[mid]].l=rest_build(l,mid);
p[mid_tra[mid]].r=rest_build(mid+1,r);
pushup(mid_tra[mid]);return mid_tra[mid];
}
需要注意的是,这里的 \(rest\_build\) 函数是把 \([l,r)\) 这段区间进行重构。最终 \(rest\_build\) 会返回整棵树的根节点。
而第 \(4\) 行我们保证了去掉被删除节点。
调用:
inline void rest(int &k){
int tail=0;
mid_travel(tail,k);
k=rest_build(0,tail);
}
调用完后,整颗以 \(k\) 为根的子树被彻底重构。
2.3 新节点与合并信息
inline void pushup(int k){
p[k].size=p[p[k].l].size+p[p[k].r].size+1;
p[k].allsize=p[p[k].l].allsize+p[p[k].r].allsize+p[k].cnt;
p[k].not_dele_size=p[p[k].l].not_dele_size+p[p[k].r].not_dele_size+(p[k].cnt!=0);
}
inline int new_node(int val){
tot++;p[tot].cnt=p[tot].size=p[tot].allsize=p[tot].not_dele_size=1;
p[tot].val=val;p[tot].l=p[tot].r=0;return tot;
}
其中 \(tot\) 是节点总数,包括被删除节点。这比较显然,不作讲解。
2.4 插入
inline void insert(int &k,int val){
if(!k){
k=new_node(val);
return;
}
if(val==p[k].val) p[k].cnt++;
else if(val<p[k].val) insert(p[k].l,val);
else insert(p[k].r,val);
pushup(k);if(can_rest(k)) rest(k);
return;
}
插入比较简单,只需要在需要重构的时候重构,注意先合并再重构。
2.5 删除
inline void delete_(int &k,int val){
if(!k) return;
if(p[k].val==val){
if(p[k].cnt) p[k].cnt--;
}
else if(val<p[k].val) delete_(p[k].l,val);
else delete_(p[k].r,val);
pushup(k);if(can_rest(k)) rest(k);
return;
}
因为替罪羊树是惰性删除,所以删除也比较显然,注意不要在第 \(4\) 行后直接写return;
因为节点 \(k\) 需要合并。
2.6 查询后继排名和前驱排名
inline int upper_rank(int k,int val){
if(!k) return 1;
else if(p[k].val==val&&p[k].cnt) return p[p[k].l].allsize+1+p[k].cnt;
else if(val<p[k].val) return upper_rank(p[k].l,val);
else return p[p[k].l].allsize+p[k].cnt+upper_rank(p[k].r,val);
}
inline int lower_rank(int k,int val){
if(!k) return 0;
if(p[k].val==val&&p[k].cnt) return p[p[k].l].allsize;
else if(p[k].val<val) return p[p[k].l].allsize+p[k].cnt+lower_rank(p[k].r,val);
else return lower_rank(p[k].l,val);
}
这里的坑点比较多,但是实现比较巧妙。注意第 \(3,9\) 行不要忘记判断被删除节点,\(4,5\) ,\(10,11\) 行不能交换,这涉及到如果 p[k].val==val
并且 p[k].cnt==0
,对于查后继排名来说,它需要进入 p[k].r
,对于查前驱排名来说,它需要进入 p[k].l
。所以两个判断不能交换,其他都比较显然。
2.7 查询值
inline int getval(int k,int rank){
if(!k) return 0;
if(p[p[k].l].allsize<rank&&rank<=p[p[k].l].allsize+p[k].cnt) return p[k].val;
else if(p[p[k].l].allsize+p[k].cnt<rank) return getval(p[k].r,rank-p[p[k].l].allsize-p[k].cnt);
else return getval(p[k].l,rank);
}
这个比较显然,注意第三行已经去除了被删除节点的影响。
2.8 查询排名,找前驱后继
inline int getrank(int k,int val){
int ans=lower_rank(k,val);
return ans+1;
}
inline int getpre(int k,int val){
return getval(k,lower_rank(k,val));
}
inline int getnext(int k,int val){
return getval(k,upper_rank(k,val));
}
这个比较显然,也不作讲解。
3 总代码
#include<bits/stdc++.h>
#define dd double
#define ld long double
#define ll long long
#define uint unsigned int
#define ull unsigned long long
#define N 401000
#define M number
using namespace std;
const int INF=0x3f3f3f3f;
const dd alpha=0.7;
template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}
struct node{
int val,l,r,cnt,size,allsize,not_dele_size;
};
node p[N];
struct ScapeGoatTree{
int root,tot,mid_tra[N];
inline void pushup(int k){
p[k].size=p[p[k].l].size+p[p[k].r].size+1;
p[k].allsize=p[p[k].l].allsize+p[p[k].r].allsize+p[k].cnt;
p[k].not_dele_size=p[p[k].l].not_dele_size+p[p[k].r].not_dele_size+(p[k].cnt!=0);
}
inline int new_node(int val){
tot++;p[tot].cnt=p[tot].size=p[tot].allsize=p[tot].not_dele_size=1;
p[tot].val=val;p[tot].l=p[tot].r=0;return tot;
}
inline bool can_rest(int k){
return (p[k].cnt)&&(alpha*(dd)p[k].size<=max((dd)p[p[k].l].size,dd(p[p[k].r].size))||((dd)p[k].not_dele_size<=alpha*(dd)p[k].size));
}
inline void mid_travel(int &tail,int k){
if(!k) return;
mid_travel(tail,p[k].l);
if(p[k].cnt) mid_tra[tail++]=k;
mid_travel(tail,p[k].r);
}
inline int rest_build(int l,int r){
if(l>=r) return 0;
int mid=l+r>>1;
p[mid_tra[mid]].l=rest_build(l,mid);
p[mid_tra[mid]].r=rest_build(mid+1,r);
pushup(mid_tra[mid]);return mid_tra[mid];
}
inline void rest(int &k){
int tail=0;
mid_travel(tail,k);
k=rest_build(0,tail);
}
inline void insert(int &k,int val){
if(!k){
k=new_node(val);
return;
}
if(val==p[k].val) p[k].cnt++;
else if(val<p[k].val) insert(p[k].l,val);
else insert(p[k].r,val);
pushup(k);if(can_rest(k)) rest(k);
return;
}
inline void delete_(int &k,int val){
if(!k) return;
if(p[k].val==val){
if(p[k].cnt) p[k].cnt--;
}
else if(val<p[k].val) delete_(p[k].l,val);
else delete_(p[k].r,val);
pushup(k);if(can_rest(k)) rest(k);
return;
}
inline int upper_rank(int k,int val){
if(!k) return 1;
else if(p[k].val==val&&p[k].cnt) return p[p[k].l].allsize+1+p[k].cnt;
else if(val<p[k].val) return upper_rank(p[k].l,val);
else return p[p[k].l].allsize+p[k].cnt+upper_rank(p[k].r,val);
}
inline int lower_rank(int k,int val){
if(!k) return 0;
if(p[k].val==val&&p[k].cnt) return p[p[k].l].allsize;
else if(p[k].val<val) return p[p[k].l].allsize+p[k].cnt+lower_rank(p[k].r,val);
else return lower_rank(p[k].l,val);
}
inline int getval(int k,int rank){
if(!k) return 0;
if(p[p[k].l].allsize<rank&&rank<=p[p[k].l].allsize+p[k].cnt) return p[k].val;
else if(p[p[k].l].allsize+p[k].cnt<rank) return getval(p[k].r,rank-p[p[k].l].allsize-p[k].cnt);
else return getval(p[k].l,rank);
}
inline int getrank(int k,int val){
int ans=lower_rank(k,val);
return ans+1;
}
inline int getpre(int k,int val){
return getval(k,lower_rank(k,val));
}
inline int getnext(int k,int val){
return getval(k,upper_rank(k,val));
}
};
ScapeGoatTree sgt;
int n;
int main(){
read(n);
for(int i=1;i<=n;i++){
int op,x;read(op);read(x);
if(op==1) sgt.insert(sgt.root,x);
else if(op==2) sgt.delete_(sgt.root,x);
else if(op==3) printf("%d\n",sgt.getrank(sgt.root,x));
else if(op==4) printf("%d\n",sgt.getval(sgt.root,x));
else if(op==5) printf("%d\n",sgt.getpre(sgt.root,x));
else if(op==6) printf("%d\n",sgt.getnext(sgt.root,x));
}
return 0;
}