替罪羊树
替罪羊树是一种依靠重构操作维持平衡的重量平衡树。替罪羊树会在插入、删除操作时,检测净土的节点,若发现失衡,则将以该节点为根的子树重构。
序言:
我们知道在一棵平衡的二叉搜索树内进行查询等操作时,时间就可以稳定在\(log(n)\)但是每一次的插入和删除节点,都会使得这棵树不平衡,最会情况就是退化成一条链,显然我们不想要这种树,于是各种维护的方法出现了,有通过旋转的,有拆树在合并的,然而替罪羊树就很优美的,因为一旦发现不平衡的子树,立即拍扁重构,于是替罪羊树的核心是:暴力重建
正题:
替罪羊树的每个节点都包含什么:
\(left\),\(right\):记录该节点的左右儿子
\(x\):该节点的值
\(tot\):有多少个值为\(x\)的数
\(sz,trsz,whsz\):\(sz\)表示以该节点为根的子树内有多少个节点,\(trsz\)表示有多少个有效节点,\(whsz\)表示有多少个数(也就数子树内所有节点的\(tot\)的和)
\(fa\):该点的父亲
\(vis\):该点是否有删除标记
操作一:加点
先找到一个特殊的节点,如果那个节点的值等于要加的那个点,那么直接让那个节点的\(tot+1\)即可,如果比那个节点的值要小,就让新加的节点称为它的左儿子,大的话就是右儿子。
那么如何找到那个"特殊的节点"?假如我们以\(x\)为关键字去查找,先从根节点开始,假如\(x\)比根节点的值要小,那我就去它的左儿子哪里,不然去右儿子,直到满足以下两个条件之一:
- 找到了值为\(x\)的节点
- 不能继续往下找
那么这个所谓的特殊的节点的性质也就很显然,就是与新加入的节点值相同的点,或者新加入的节点的前驱或后继
找点:
int find(int x,int now){//now表示当前找到哪个点
if(x < tr[now].x && tr[now].left) return find(x,tr[now].left);//比当前点的值要小并且有左儿子
if(x > tr[now].x && tr[now].right) return find(x,tr[now].right);
return now;
}
加点:
void add(int x){
if(root == 0){//假如当前没有根节点,也就是当前的树是空的,那么直接让他成为根
build(x,root = New(),0);//新建节点(后面有讲)
return;
}
int p = find(x,root);//找到特殊点
if(x == tr[p].x{
tr[p].tot++;
if(tr[p].vis) tr[p].vis = 0,updata(p,1,0,1);
else updata(p,0,0,1);
}
else if(x < tr[p].x) build(x,tr[p].left = New(),p),updata(p,1,1,1);
else build(x,tr[p].right = New(),p),updata(p,1,1,1);
find_rebuild(root,x);
}
上面用到的几个函数:
新建节点:
void build(int x,int y,int fa){//初始化树上编号为y的节点,它的值为x,父亲为fa
tr[y].left = tree[y].right = 0;tr[y].fa = fa;tree[y].vis = 0;
tr[y].x = x;tr[y].tot = tr[y].sz = tr[y].trsz = tr[y].whsz = 1;
}
\(updata\)函数,更新父亲以及爷爷以及祖先们的\(sz,trsz还有whsz\):
void updata(int x,int y,int z,int k){
if(!x) return;
tr[x].trsz += y;
tr[x].sz += z;
tr[x].whsz += k;
updata(tr[x].fa,y,z,k);
}
\(New\)函数就是个内存池。
操作二:删点
删点,严格来说是删掉一个数,假如我们要删掉一个值为\(x\)的数,那就先找到值为\(x\)的节点,然后\(tot-1\)
难道就这么简单么?当然不是,假如\(tot-1\)之后变成\(0\)了怎么办?这意味着这个节点不存在了,然后我们删掉这个节点么?如果把它删了,他的左右儿子怎么办?所以我们不能动他,给它打个标记,标记它被删除了
代码:
void del(int x){
int p = find(x,root);
tr[p].tot--;
if(!tr[p].tot) tr[p].vis = true,updata(p,-1,0,-1);
else updata(p,0,0,-1);
find_rebuild(root,x);
}
我们上面的代码提到了一个函数\(find\)_\(rebuild\),每一次的加点和删点都有可能使这棵树不平衡,假如有一颗子树不平衡,我们就需要将其重建,所以,\(find\) _\(rebuild\)就是用来查找需要重建的子树。
先说一下怎么重建。
因为需要重建的子树比订书二叉搜索树,那么这棵子树的中序遍历一定是一个严格上升的序列,于是我们就先中序遍历一下,把树上的有效节点放到一个数组里里面,注意无效节点(就是被打了删除标记的节点)不要。
然后我们在把数组中的节点重建成一棵极其平衡的完全二叉树(按完全二叉树的方法来建),具体放大就是每一次选取数组中间的节点,让它成为跟,左边的当它左子树,右边的当它右子树,然后再对左右儿子进行相同的操作。
怎么找需要重建的子树:
我们每次\(add\)或\(del\)的数为\(x\),在将这个数加入到树中或从树中删除之后,加入在树中值为\(x\)的节点是\(y\),那我们考虑到其实每一次可能小重构的子树只会是以“根到\(y\)路径上的节点”为根的子树,那么我们就可以从根往\(y\)走一次,看看谁要重建就好了。不是\(y\)往根走的原因:如果根到\(y\)的路径上只有两个点\(a,b\),并且\(a\)是\(b\)的祖先,然后特别巧的是\(a,b\)都是需要重建的,那么这个时候我们只要重建祖先节点为根的子树,因为重建之后,另一个为根的子树在其内部也重建完了,如果\(y\)往根走就会出现重建两遍的情况。
判断替罪羊树是否平衡:
在替罪羊树中定义了个一个平衡因子\(\alpha\),\(\alpha\)的范围因题而异,一般在\(0.5-1.0\)之间。判断一棵子树是否平衡的方法:如果\(x\)的左(右)儿子的节点数量大于以\(x\)为根的子树的节点数量\(*\alpha\),那么以\(x\)为根的这棵子树就是不平衡的,这时就将它重建。
还有一种情况就是,打了删除标记的点多了,效率自然会变慢,所欲如果在一棵子树内有超过30%的几点被打了删除标记,就把这棵树重建。
\(find\)_\(rebuild\):
void find_rebuild(int now,int x){
if(1.0 * tr[tr[now].left].sz > 1.0 * tr[now].sz * alpha || 1.0 * tr[tr[now].right].sz > 1.0 * tr[now].sz * alpha|| 1.0 * tr[now].sz - 1.0 * tr[now].trsz >1.0 * tr[now].sz * 0.3){
rebuild(now);
return;
}
if(tr[now].x != x) find_rebuild(x < tr[now].x ? tr[now].left : tree[now].right,x);
}
\(rebuild\):
void rebuild(int x){//重建以x为根的子树
tt = 0;
dfs_rebuild(x);//进行中序遍历并将有效节点压入数组
if(x == root) root = readd(1,tt,0);//x就是根,那么root就变成重建之后的那棵树的根
//readd用来把数组里的节点重新建成一棵完全二叉树,并返回这棵树的根
else{
updata(tr[x].fa,0,-(tr[x].sz - tree[x].trsz),0);//因为拍扁重建后的树中,被打了删除标记的节点将消失,所以要将祖先们的size进行更改,也就是减去被删去的节点
if(tr[tr[x].fa].left == x) tr[tr[x].fa].left = readd(1,tt,tr[x].fa);
else tr[tr[x].fa].right = readd(1,tt,tr[x].fa);
}
}
\(readd\):
int readd(int l,int r,int fa){
if(l > r)return 0
int mid = (l + r) >> 1;//选中间的点作为根
int id = New();
tree[id].fa = fa;//更新各项
tree[id].tot = num[mid].tot;
tree[id].x = num[mid].x;
tree[id].left = readd(l,mid - 1,id);
tree[id].right = readd(mid + 1,r,id);
tree[id].whsz = tr[tr[id].left].whsz + tr[tr[id].right].whsz + num[mid].tot;
tree[id].sz = tr[id].trsz = r - l + 1;
tree[id].vis = 0;
return id;
}
中序遍历:
void dfs_rebuild(int x){
if(x == 0) return;
dfs_rebuild(tr[x].left);
if(!tr[x].vis) num[++tt].x = tr[x].x,num[tt].tot = tree[x].tot;//假如没有删除标记,就只将他的x和tot加进数组,因为其他东西都没有用
ck[++t] = x;//仓库,存下废弃的节点
dfs_rebuild(tr[x].right);
}
然后就是\(New\):
int New(){
if(t > 0) return ck[t--];//假如仓库内有点,就直接用
else return ++len;//否则再创造一个点
}
然后我们就可以进行剩下的几个基本操作了
操作三:查找\(x\)的排名:
我们只需要像\(find\)函数一样走一遍行。如果往右儿子走,就让\(ans\)加上左儿子的数的个数,再加上当前节点的\(tot\),否则就往左儿子走,走到值为\(x\)的点结束。
void findx(int x){
int now = root;
int ans = 0;
while(tr[now].x != x){
if(x < tr[now].x) now = tr[now].left;
else ans += tr[tr[now].left].whsz + tr[now].tot,now = tr[now].right;
}
ans += tr[tr[now].left].whsz;
printf("%d\n",ans + 1);
}
操作四:查找排名为\(x\)的数
类似的,我们先从根走,假如当前节点的左子树的数的个数比\(x\)要小,那么让\(x\)减掉左子树的数的个数,然后在看一下当前节点的\(tot\)是否大于\(x\),如果大于的话,答案就是这个节点了,否则让\(x\)减去它的\(tot\),然后往右儿子去,重复以上操作即可。
void findrkx(int x){
int now = root;
while(1){
if(x <= tr[tr[now].left].whsz) now = tr[now].left;
else{
x -= tr[tr[now].left].whsz;
if(x <= tr[now].tot){
printf("%d\n",tr[now].x);
return;
}
x -= tr[now].tot;
now = tr[now].right;
}
}
}
要注意!这两个函数里用的都是\(whsz\)
操作五:查找\(x\)的前驱
因为替罪羊树有删除标记这个东西,所以它查找前驱和后继的时候会慢一点。
具体做法:先找到值为\(x\)的节点,然后普看看有没有左儿子,如果有就将左子树遍历一遍,顺序是:右儿子->根->左儿子,找到第一个没有被删除的节点就是答案。
因为被删除的点不超过30%,所以不用担心算法会退化成\(O(n)\)
void dfs_rml(int x){
if(tr[x].right != 0) dfs_rml(tr[x].right);
if(ans) return;
if(!tr[x].vis){
printf("%d\n",tr[x].x);
ans = 1;
return;
}
if(tr[x].left != 0) dfs_rml(tr[x].left);
}
void pre(int now,int x,bool z){
if(!z){
pre(tr[now].fa,x,tr[tr[now].fa].right == now);
return;
}
if(!tr[now].vis && tr[now].x < x){
printf("%d\n",tr[now].x);
return;
}
if(tr[now].left){
ans = 0;
dfs_rml(tr[now].left);
return;
}
pre(tr[now].fa,x,tr[tr[now].fa].right == now);
}
操作六:查找\(x\)的后继
跟前驱类似
void dfs_lmr(int x){
if(tr[x].left != 0) dfs_lmr(tr[x].left);
if(ans) return;
if(!tr[x].vis){
printf("%d\n",tre[x].x);
ans = 1;
return;
}
if(tr[x].right != 0) dfs_lmr(tr[x].right);
}
void nxt(int now,int x,bool z){
if(!z){
nxt(tr[now].fa,x,tr[tr[now].fa].right != now);
return;
}
if(!tr[now].vis && tr[now].x > x){
printf("%d\n",tr[now].x);
return;
}
if(tr[now].right){
ans = 0;
dfs_lmr(tr[now].right);
return;
}
nxt(tr[now].fa,x,tr[tr[now].fa].right != now);
}
后记:
-
关于\(\alpha\):
\(\alpha\)的值究竟与效率的关系,当的\(\alpha\)值越小,那么替罪羊树就越容易重构,那么树也就越平衡,查询的效率也就越高,自然修改(加点和删点)的效率也就低了。所以,如果查询操作比较多的话,就可以将\(\alpha\)的值设小一点。反之,假如修改操作多,自然\(\alpha\)的值就要大一点了。
还有,\(\alpha\)不能等于\(1\) \(or\) \(0.5\),假如它等于\(0.5\),那么当一棵树被重构之后如果因为节点数问题,不能完全重构成一个完全二叉树,那么显然,对于这棵树的根,他的"左子树节点数量 - 右子树节点数量"很可能会等于\(1\),那么如果往多的那棵子树上加一个节点,那么这棵树又得重构一次,最坏情况时间会变成\(n^2\)。那么等于1...会有一棵子树的大小大于整棵树的大小咩w?
-
关于时间复杂度:
除了重构操作,其他操作的时间复杂度显然都是\(log(n)\)的,那么下面看一下重构的时间复杂度。
虽然重构一次的时间复杂度是\(O(n)\)的,但是,均摊下来其实只是\(O(logn)\)。
考虑极端情况,每次都把整棵树重构。
那么我们就需要每次都往根的一棵子树内加点,假设一开始是平衡的,那么左右子树各有50%的节点,那么要使一棵子树内含有超过75%的节点,那么这棵子树就需要在原来的基础上增加\(2\)倍的节点数。也就是说,当最差情况时,整棵替罪羊树的节点数要翻个倍,才会重构。那么最差情况时也就是在\(4,8,16,32……\)个节点时才会重构,于是重构的总的时间复杂度也就是\(O(nlogn)\)了,加上一些杂七杂八的重构,也不过就是加上一个很小的常数,可以省略不计。所以,替罪羊树的时间复杂度依然是\(O(nlogn)\)的。
完整代码
#define B cout << "BreakPoint" << endl;
#define O(x) cout << #x << " " << x << endl;
#define O_(x) cout << #x << " " << x << " ";
#define Msz(x) cout << "Sizeof " << #x << " " << sizeof(x)/1024/1024 << " MB" << endl;
#include<cstdio>
#include<cmath>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
#include<set>
#define LL long long
const int inf = 1e9 + 9;
const int N = 1e7 + 5;
using namespace std;
inline int read() {
int s = 0,w = 1;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-')
w = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
s = s * 10 + ch - '0';
ch = getchar();
}
return s * w;
}
struct node{
int left,right,x,tot,sz,trsz,whsz,fa;
bool vis;
} tr[N];
struct sl{
int x,tot;
}num[N];
int len,n,root,ck[N],t;
double alpha = 0.75;
void build(int x,int y,int fa){
tr[y].left = tr[y].right = 0;
tr[y].fa = fa,tr[y].vis = false;
tr[y].x = x,tr[y].tot = tr[y].sz = tr[y].trsz = tr[y].whsz = 1;
}
inline int New(){
if(t > 0) return ck[t--];
else return ++len;
}
void updata(int x,int y,int z,int k){
if(!x) return;
tr[x].trsz += y;
tr[x].sz += z;
tr[x].whsz += k;
updata(tr[x].fa,y,z,k);
}
int find(int x,int now){
if(x < tr[now].x && tr[now].left) return find(x,tr[now].left);
if(x > tr[now].x && tr[now].right) return find(x,tr[now].right);
return now;
}
int tt;
void dfs_rebuild(int x){
if(x == 0)return;
dfs_rebuild(tr[x].left);
if(!tr[x].vis) num[++tt].x = tr[x].x,num[tt].tot = tr[x].tot;
ck[++t] = x;
dfs_rebuild(tr[x].right);
}
int readd(int l,int r,int fa){
if(l > r) return 0;
int mid = (l+r)>>1;
int id = New();
tr[id].fa = fa;
tr[id].tot = num[mid].tot;
tr[id].x = num[mid].x;
tr[id].left = readd(l,mid-1,id);
tr[id].right = readd(mid+1,r,id);
tr[id].whsz = tr[tr[id].left].whsz + tr[tr[id].right].whsz + num[mid].tot;
tr[id].sz = tr[id].trsz = r - l + 1;
tr[id].vis = false;
return id;
}
void rebuild(int x){
tt = 0;
dfs_rebuild(x);
if(x == root) root = readd(1,tt,0);
else{
updata(tr[x].fa,0,-tr[x].sz + tr[x].trsz,0);
if(tr[tr[x].fa].left == x) tr[tr[x].fa].left = readd(1,tt,tr[x].fa);
else tr[tr[x].fa].right = readd(1,tt,tr[x].fa);
}
}
void find_rebuild(int now,int x){
if(1.0 * tr[tr[now].left].sz > 1.0 * tr[now].sz * alpha || 1.0 * tr[tr[now].right].sz > 1.0 * tr[now].sz * alpha|| 1.0 * tr[now].sz - 1.0 * tr[now].trsz >1.0 * tr[now].sz * 0.3){
rebuild(now);
return;
}
if(tr[now].x != x) find_rebuild(x < tr[now].x ? tr[now].left : tr[now].right,x);
}
void add(int x){
if(root == 0){
build(x,root = New(),0);
return;
}
int p = find(x,root);
if(x == tr[p].x){
tr[p].tot++;
if(tr[p].vis) tr[p].vis = 0,updata(p,1,0,1);
else updata(p,0,0,1);
}
else if(x < tr[p].x) build(x,tr[p].left = New(),p),updata(p,1,1,1);
else build(x,tr[p].right = New(),p),updata(p,1,1,1);
find_rebuild(root,x);
}
void del(int x){
int p = find(x,root);
tr[p].tot--;
if(!tr[p].tot) tr[p].vis = 1,updata(p,-1,0,-1);
else updata(p,0,0,-1);
find_rebuild(root,x);
}
void findx(int x){
int now = root;
int ans = 0;
while(tr[now].x != x){
if(x < tr[now].x) now = tr[now].left;
else ans += tr[tr[now].left].whsz + tr[now].tot,now = tr[now].right;
}
ans += tr[tr[now].left].whsz;
printf("%d\n",ans + 1);
}
void findrkx(int x){
int now = root;
while(1){
if(x <= tr[tr[now].left].whsz) now = tr[now].left;
else{
x -= tr[tr[now].left].whsz;
if(x <= tr[now].tot){
printf("%d\n",tr[now].x);
return;
}
x -= tr[now].tot;
now = tr[now].right;
}
}
}
bool ans;
void dfs_rml(int x){
if(tr[x].right != 0) dfs_rml(tr[x].right);
if(ans) return;
if(!tr[x].vis){
printf("%d\n",tr[x].x);
ans = 1;
return;
}
if(tr[x].left != 0) dfs_rml(tr[x].left);
}
void pre(int now,int x,bool z){
if(!z){
pre(tr[now].fa,x,tr[tr[now].fa].right == now);
return;
}
if(!tr[now].vis && tr[now].x < x){
printf("%d\n",tr[now].x);
return;
}
if(tr[now].left){
ans = 0;
dfs_rml(tr[now].left);
return;
}
pre(tr[now].fa,x,tr[tr[now].fa].right == now);
}
void dfs_lmr(int x){
if(tr[x].left != 0) dfs_lmr(tr[x].left);
if(ans) return;
if(!tr[x].vis){
printf("%d\n",tr[x].x);
ans = 1;
return;
}
if(tr[x].right != 0) dfs_lmr(tr[x].right);
}
void nxt(int now,int x,bool z){
if(!z){
nxt(tr[now].fa,x,tr[tr[now].fa].right != now);
return;
}
if(!tr[now].vis && tr[now].x > x){
printf("%d\n",tr[now].x);
return;
}
if(tr[now].right){
ans = 0;
dfs_lmr(tr[now].right);
return;
}
nxt(tr[now].fa,x,tr[tr[now].fa].right != now);
}
int main(){
n = read();
while(n--){
int id = read(),x = read();
if(id == 1) add(x);
if(id == 2) del(x);
if(id == 3) findx(x);
if(id == 4) findrkx(x);
if(id == 5) pre(find(x,root),x,1);
if(id == 6) nxt(find(x,root),x,1);
}
}