AVL树模板
AVL树是平衡二叉搜索树的一种，它通过保证每个节点左右子树的高度差不超过1来维持平衡。
这里我用指针实现AVL树，经过一番调试，在洛谷上通过了普通平衡树模板题的所有数据。
以下是模板的C++代码：
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
/*begin input&output accelerate*/
namespace IO {
/**
 * Reads one integer of type T from stdin using getchar().
 *
 * Skips every non-digit character before the number; if a '-' appears among
 * the skipped characters the result is negated. Reading stops at the first
 * non-digit after the number (that character is consumed).
 *
 * @return the parsed value; 0 when stdin reaches EOF before any digit
 *         (previously this case looped forever).
 */
template <typename T>inline T read()
{
    // int, not char: getchar() returns EOF (-1) at end of input, which a plain
    // char either truncates or (if char is unsigned) turns into 255 forever.
    int c=getchar();
    T ans=0;
    bool f=false;
    while(c!=EOF&&(c<'0'||c>'9'))
    {
        if(c=='-')f=true;
        c=getchar();
    }
    // EOF (-1) is never in ['0','9'], so this loop also stops at end of input.
    while('0'<=c&&c<='9')
    {
        ans=ans*10+c-'0';
        c=getchar();
    }
    return f?-ans:ans;
}
}
/*end input&output accelerate*/
/*begin AVL Tree*/
namespace AVL {
/*begin basic functions*/
// Absolute value of x.
template <typename T>inline T abs(T x)
{
    if(x>=0)return x;
    return -x;
}
// Larger of a and b; returns b when neither is greater.
template <typename T>inline T max(T a,T b)
{
    return a>b?a:b;
}
// Smaller of a and b; returns b when neither is smaller.
template <typename T>inline T min(T a,T b)
{
    return a<b?a:b;
}
/*end basic functions*/
/*begin node*/
/**
 * One tree node. Equal keys share a single node and `num` counts the copies.
 *   lef, rig : left / right child, NULL when absent
 *   dep      : height of the subtree rooted here (a fresh leaf has dep 1)
 *   num      : multiplicity of `key`
 *   sum      : total multiplicity of the whole subtree (drives rank / kth)
 */
template <typename T>struct Node
{
    Node<T> *lef,*rig;
    int dep,num,sum;
    T key;
    // Initializers are listed in declaration order (avoids -Wreorder).
    Node():lef(NULL),rig(NULL),dep(0),num(0),sum(0),key(0) {}
    Node(T _key,Node<T> *_lef=NULL,Node<T> *_rig=NULL):lef(_lef),rig(_rig),dep(1),num(1),sum(1),key(_key) {}
    /**
     * Refreshes dep and sum from the children's cached dep/sum and returns the
     * balance factor height(right) - height(left).
     * Only the direct children are read, so their cached values must be current.
     */
    int fac()
    {
        int ldep=0,lsum=0;
        int rdep=0,rsum=0;
        if(lef!=NULL)ldep=lef->dep,lsum=lef->sum;
        if(rig!=NULL)rdep=rig->dep,rsum=rig->sum;
        dep=max(ldep,rdep)+1;
        sum=lsum+rsum+num;
        return rdep-ldep;
    }
};
/*end node*/
// Height of subtree c (0 when empty); refreshes c's cached dep/sum as a side effect.
template <typename T>int H(Node<T> *c)
{
    if(c==NULL)return 0;
    c->fac();
    return c->dep;
}
/*begin avl tree body*/
/**
 * AVL-balanced multiset of keys: insert, remove one copy, rank, k-th smallest,
 * predecessor and successor.
 * The tree owns its nodes: they are freed by the destructor and deep-copied on
 * copy construction / copy assignment.
 */
template <typename T>class Tree
{
public:
    Tree():root(NULL),prekey(),nexkey() {}
    // Deep copy, so later updates to one tree never leak into the other.
    Tree(const Tree &other):root(clone(other.root)),prekey(other.prekey),nexkey(other.nexkey) {}
    Tree &operator=(const Tree &other)
    {
        if(this!=&other)
        {
            // Clone first so *this is untouched if an allocation throws.
            Node<T> *fresh=clone(other.root);
            destroy(root);
            root=fresh;
            prekey=other.prekey;
            nexkey=other.nexkey;
        }
        return *this;
    }
    ~Tree()
    {
        destroy(root);
    }
    /// Returns the node that stores `key`, or NULL when the key is absent.
    Node<T> *search(T key)
    {
        Node<T> *c=root;
        while(c!=NULL)
        {
            if(c->key==key)return c;
            // if/else: after stepping left, c may be NULL and must not be
            // dereferenced again in the same iteration.
            if(c->key>key)c=c->lef;
            else c=c->rig;
        }
        return NULL;
    }
    /// Adds one copy of `key`.
    void insert(T key)
    {
        insert(root,key);
    }
    /// Removes one copy of `key`; does nothing when the key is absent.
    void remove(T key)
    {
        remove(root,key,1);
    }
    /// 1 + number of stored elements (with multiplicity) strictly less than `key`.
    /// `key` does not need to be present.
    int rank(T key)
    {
        return rank(root,key);
    }
    /// k-th smallest stored element (1-based, duplicates counted).
    /// Returns T() when k exceeds the number of stored elements.
    T kth(int k)
    {
        return kth(root,k);
    }
    /// Largest stored key strictly less than `key`, or T(-1) if there is none.
    /// NOTE: the -1 sentinel is ambiguous when -1 itself can be stored.
    T prev(T key)
    {
        prekey=-1;
        prev(root,key);
        return prekey;
    }
    /// Smallest stored key strictly greater than `key`, or T(-1) if there is none.
    /// NOTE: the -1 sentinel is ambiguous when -1 itself can be stored.
    T next(T key)
    {
        nexkey=-1;
        next(root,key);
        return nexkey;
    }
private:
    /*begin find max*/
    // Rightmost (largest-key) node of subtree c, or NULL when c is empty.
    Node<T> *maxone(Node<T>*c)
    {
        if(c!=NULL)
        {
            while(c->rig!=NULL)
                c=c->rig;
            return c;
        }
        return NULL;
    }
    /*end find max*/
    /*begin find min*/
    // Leftmost (smallest-key) node of subtree c, or NULL when c is empty.
    Node<T> *minone(Node<T>*c)
    {
        if(c!=NULL)
        {
            while(c->lef!=NULL)
                c=c->lef;
            return c;
        }
        return NULL;
    }
    /*end find min*/
    /*begin rotate*/
    // Right rotation: c's left child becomes the subtree root, which is returned.
    // c itself is not reseated; the caller stores the returned pointer.
    Node<T> *clockwise_rotate(Node<T>*&c)
    {
        Node<T> *r=c->lef;
        c->lef=r->rig;
        r->rig=c;
        c->fac();   // c is now the child: refresh it before its new parent
        r->fac();
        return r;
    }
    // Left rotation: c's right child becomes the subtree root, which is returned.
    Node<T> *anti_clockwise_rotate(Node<T>*&c)
    {
        Node<T> *r=c->rig;
        c->rig=r->lef;
        r->lef=c;
        c->fac();
        r->fac();
        return r;
    }
    /*end rotate*/
    /*begin maintain*/
    /**
     * Restores the AVL property at c, assuming both children are already
     * balanced with current dep/sum. Applies at most two rotations, reseats c
     * to the new subtree root and leaves its dep/sum current.
     */
    void maintain(Node<T>*&c)
    {
        if(c==NULL)
            return;
        int fac=c->fac();
        if(fac<-1)
        {
            // Left-left, or left child exactly balanced (only possible after a
            // removal): one right rotation suffices. A double rotation in the
            // balanced case can leave the old left child off by 2.
            if(H(c->lef->lef)>=H(c->lef->rig))
                c=clockwise_rotate(c);
            // left-right child
            else
            {
                c->lef=anti_clockwise_rotate(c->lef);
                c=clockwise_rotate(c);
            }
        }
        else if(fac>1)
        {
            // Right-right, or right child exactly balanced (after a removal).
            if(H(c->rig->rig)>=H(c->rig->lef))
                c=anti_clockwise_rotate(c);
            // right-left child
            else
            {
                c->rig=clockwise_rotate(c->rig);
                c=anti_clockwise_rotate(c);
            }
        }
        c->fac();
    }
    /*end maintain*/
    /*begin insert*/
    // Inserts one copy of key below c (c is updated in place) and returns the new subtree root.
    Node<T> *insert(Node<T>*&c,T key)
    {
        if(c==NULL)
            c=new Node<T>(key);
        else
        {
            if(key==c->key)c->num++;
            //key goes into the left child;
            if(key<c->key)
                c->lef=insert(c->lef,key);
            //key goes into the right child;
            if(key>c->key)
                c->rig=insert(c->rig,key);
            maintain(c);
        }
        return c;
    }
    /*end insert*/
    /*begin remove*/
    /**
     * Removes s copies of key from subtree c and rebalances on the way back up.
     * Public remove passes s = 1. The two-children case copies the neighbouring
     * node (predecessor or successor, taken from the taller side) into c and
     * passes that node's whole multiplicity so it is physically unlinked.
     */
    void remove(Node<T>*&c,T key,int s)
    {
        if(c==NULL)return;
        if(c->key>key)
        {
            remove(c->lef,key,s);
            maintain(c);
            return;
        }
        if(c->key<key)
        {
            remove(c->rig,key,s);
            maintain(c);
            return;
        }
        c->num-=s;
        if(c->num==0)
        {
            if(c->lef!=NULL&&c->rig!=NULL)
            {
                if(H(c->lef)>H(c->rig))
                {
                    Node<T>*temp=maxone(c->lef);
                    c->key=temp->key;
                    c->num=temp->num;
                    remove(c->lef,temp->key,temp->num);
                }
                else
                {
                    Node<T>*temp=minone(c->rig);
                    c->key=temp->key;
                    c->num=temp->num;
                    remove(c->rig,temp->key,temp->num);
                }
            }
            else
            {
                // At most one child: splice it into c's place and free the node.
                Node<T>*temp=c;
                if(c->lef!=NULL)
                    c=c->lef;
                else
                    c=c->rig;
                delete temp;
            }
        }
        maintain(c);
    }
    /*end remove*/
    /*begin rank*/
    // 1 + number of elements in subtree c strictly less than key.
    int rank(Node<T>*c,T key)
    {
        // Empty subtree: contributes the "+1" of the rank definition, so an
        // absent key still gets (count of smaller elements) + 1.
        if(c==NULL)
            return 1;
        int smaller=(c->lef!=NULL)?c->lef->sum:0;
        if(c->key==key)
            return smaller+1;
        if(c->key>key)
            return rank(c->lef,key);
        return smaller+c->num+rank(c->rig,key);
    }
    /*end rank*/
    /*begin kth*/
    // k-th smallest element of subtree c; T() when k is larger than the subtree.
    T kth(Node<T>*c,int k)
    {
        if(c==NULL)
            return T();
        if(c->lef!=NULL)
        {
            if(c->lef->sum>=k)
                return kth(c->lef,k);
            else
                k-=c->lef->sum;
        }
        if(k<=c->num)return c->key;
        return kth(c->rig,k-c->num);
    }
    /*end kth*/
    /*begin find previous*/
    // Records in prekey the largest key < key seen on the search path.
    void prev(Node<T>*c,T key)
    {
        if(c==NULL)return;
        if(c->key<key)
        {
            prekey=c->key;
            prev(c->rig,key);
        }
        else
            prev(c->lef,key);
    }
    /*end find previous*/
    /*begin find next*/
    // Records in nexkey the smallest key > key seen on the search path.
    void next(Node<T>*c,T key)
    {
        if(c==NULL)return;
        if(c->key>key)
        {
            nexkey=c->key;
            next(c->lef,key);
        }
        else
            next(c->rig,key);
    }
    /*end find next*/
    /*begin ownership helpers*/
    // Allocates an independent copy of subtree c (NULL for an empty subtree).
    Node<T> *clone(const Node<T> *c)
    {
        if(c==NULL)return NULL;
        Node<T> *copy=new Node<T>(*c);
        copy->lef=clone(c->lef);
        copy->rig=clone(c->rig);
        return copy;
    }
    // Frees every node of subtree c.
    void destroy(Node<T> *c)
    {
        if(c==NULL)return;
        destroy(c->lef);
        destroy(c->rig);
        delete c;
    }
    /*end ownership helpers*/
private:
    Node<T>*root;
    T prekey,nexkey;   // scratch results for prev()/next()
};
/*end avl tree body*/
}
/*end AVL Tree*/
int main() {
using namespace IO;
using namespace AVL;
int (*R)()=read<int>;
Tree<int> Tr;
int n=R();
while(n--)
{
int f=R();
int x=R();
switch(f)
{
case 1:
Tr.insert(x);
break;
case 2:
Tr.remove(x);
break;
case 3:
printf("%d\n",Tr.rank(x));
break;
case 4:
printf("%d\n",Tr.kth(x));
break;
case 5:
printf("%d\n",Tr.prev(x));
break;
case 6:
printf("%d\n",Tr.next(x));
break;
default:
break;
}
}
// Tr.main();
return 0;
}