/* 红黑树 — red-black tree */
#include "common.h"

#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
/*
 * A red-black tree node. Every link is NULL when absent and the root's
 * m_parent is NULL. m_red is true for red nodes and false for black ones.
 */
typedef struct rb_node_t
{
    struct rb_node_t *m_parent; /* node above this one; NULL for the root */
    struct rb_node_t *m_left;   /* subtree of smaller-or-equal keys */
    struct rb_node_t *m_right;  /* subtree of greater-or-equal keys */
    bool m_red;                 /* colour: true = red, false = black */
    int m_value;                /* the key stored in this node */
} rb_node_t;
/*
 * Allocate a zero-initialised node holding `value` whose parent pointer is
 * `parent`. The node starts red, as insertion requires. The caller is
 * responsible for linking it into the parent's child slot, and for
 * eventually releasing it with free().
 *
 * Aborts the process if allocation fails instead of returning NULL, because
 * callers store the result straight into the tree without checking it.
 */
rb_node_t *rb_node_new(rb_node_t *parent, int value)
{
    rb_node_t *p = calloc(1, sizeof *p);
    if (p == NULL)
    {
        fprintf(stderr, "rb_node_new: out of memory\n");
        abort();
    }
    p->m_parent = parent;
    p->m_red = true;
    p->m_value = value;
    return p;
}
/*
 * Left rotation around `p`: p's right child (the pivot) takes p's place,
 * p becomes the pivot's left child, and the pivot's former left subtree
 * becomes p's right subtree. *root is updated when p was the root.
 * Requires p->m_right != NULL.
 */
void rotate_left(rb_node_t **root, rb_node_t *p)
{
    rb_node_t *const pivot = p->m_right;
    rb_node_t *const parent = p->m_parent;
    rb_node_t *const inner = pivot->m_left;

    /* The pivot's inner subtree moves across to p. */
    p->m_right = inner;
    if (inner != NULL)
    {
        inner->m_parent = p;
    }

    /* Hang the pivot where p used to be. */
    pivot->m_parent = parent;
    if (parent == NULL)
    {
        *root = pivot;
    }
    else if (parent->m_left == p)
    {
        parent->m_left = pivot;
    }
    else
    {
        parent->m_right = pivot;
    }

    /* Finally p sits below the pivot on the left. */
    pivot->m_left = p;
    p->m_parent = pivot;
}
/*
 * Right rotation around `p`: p's left child (the pivot) takes p's place,
 * p becomes the pivot's right child, and the pivot's former right subtree
 * becomes p's left subtree. *root is updated when p was the root.
 * Requires p->m_left != NULL.
 */
void rotate_right(rb_node_t **root, rb_node_t *p)
{
    rb_node_t *const pivot = p->m_left;
    rb_node_t *const parent = p->m_parent;
    rb_node_t *const inner = pivot->m_right;

    /* The pivot's inner subtree moves across to p. */
    p->m_left = inner;
    if (inner != NULL)
    {
        inner->m_parent = p;
    }

    /* Hang the pivot where p used to be. */
    pivot->m_parent = parent;
    if (parent == NULL)
    {
        *root = pivot;
    }
    else if (parent->m_left == p)
    {
        parent->m_left = pivot;
    }
    else
    {
        parent->m_right = pivot;
    }

    /* Finally p sits below the pivot on the right. */
    pivot->m_right = p;
    p->m_parent = pivot;
}
/*
 * Restore the red-black invariants after a red node was attached below `p`
 * (rb_insert passes the new node's parent). While `p` is red there is a
 * red-red violation between `p` and its red child:
 *   - red uncle: recolour and move the check two levels up;
 *   - black/absent uncle: one or two rotations around the grandparent
 *     finish the repair.
 */
void fix_insert(rb_node_t **root, rb_node_t *p)
{
    while (p->m_red)
    {
        /* p is red, hence not the (black) root, so it has a parent. */
        rb_node_t *grand = p->m_parent;
        const bool p_is_left = (grand->m_left == p);
        rb_node_t *uncle = p_is_left ? grand->m_right : grand->m_left;

        if (uncle != NULL && uncle->m_red)
        {
            p->m_red = false;
            uncle->m_red = false;
            if (grand->m_parent == NULL)
            {
                /* grand is the root and stays black: done. */
                return;
            }
            grand->m_red = true;
            /* grand may now clash with its own parent. */
            p = grand->m_parent;
            continue;
        }

        if (p_is_left)
        {
            /* Inner (left-right) case: straighten it into the outer case first. */
            if (p->m_right != NULL && p->m_right->m_red)
            {
                rotate_left(root, p);
                p = p->m_parent;
            }
            p->m_red = false;
            grand->m_red = true;
            rotate_right(root, grand);
        }
        else
        {
            /* Inner (right-left) case: straighten it into the outer case first. */
            if (p->m_left != NULL && p->m_left->m_red)
            {
                rotate_right(root, p);
                p = p->m_parent;
            }
            p->m_red = false;
            grand->m_red = true;
            rotate_left(root, grand);
        }
        return;
    }
}
/*
 * Rebalance before rb_remove unlinks the childless node `p`. `p` is still
 * attached while this runs. If it is black, its removal would shorten
 * every path through it, so it is treated as "double black" and the deficit
 * is pushed up or resolved with recolouring and rotations around its
 * parent and sibling. A red `p` needs no work beyond the final recolour.
 */
void fix_remove(rb_node_t **root, rb_node_t *p)
{
    while (p->m_parent != NULL && !p->m_red)
    {
        rb_node_t *parent = p->m_parent;

        if (parent->m_left == p)
        {
            /* p is black and non-NULL, so its sibling must exist. */
            rb_node_t *sib = parent->m_right;
            if (sib->m_red)
            {
                /* Red sibling: rotate so p gets a black sibling instead. */
                sib->m_red = false;
                parent->m_red = true;
                rotate_left(root, parent);
                sib = parent->m_right;
            }

            const bool near_red = (sib->m_left != NULL) && sib->m_left->m_red;
            const bool far_red = (sib->m_right != NULL) && sib->m_right->m_red;
            if (!near_red && !far_red)
            {
                /* Black sibling with black children: recolour, move up. */
                sib->m_red = true;
                p = parent;
                continue;
            }
            if (!far_red)
            {
                /* Only the near nephew is red: turn it into the far case. */
                sib->m_left->m_red = false;
                sib->m_red = true;
                rotate_right(root, sib);
                sib = parent->m_right;
            }
            /* Far nephew red: one rotation absorbs the missing black. */
            sib->m_red = parent->m_red;
            parent->m_red = false;
            sib->m_right->m_red = false;
            rotate_left(root, parent);
        }
        else
        {
            /* Mirror image of the branch above. */
            rb_node_t *sib = parent->m_left;
            if (sib->m_red)
            {
                sib->m_red = false;
                parent->m_red = true;
                rotate_right(root, parent);
                sib = parent->m_left;
            }

            const bool near_red = (sib->m_right != NULL) && sib->m_right->m_red;
            const bool far_red = (sib->m_left != NULL) && sib->m_left->m_red;
            if (!near_red && !far_red)
            {
                sib->m_red = true;
                p = parent;
                continue;
            }
            if (!far_red)
            {
                sib->m_right->m_red = false;
                sib->m_red = true;
                rotate_left(root, sib);
                sib = parent->m_left;
            }
            sib->m_red = parent->m_red;
            parent->m_red = false;
            sib->m_left->m_red = false;
            rotate_right(root, parent);
        }
        break;
    }
    /* Absorbs an extra black when the loop stopped on a red node. */
    p->m_red = false;
}
/*
 * Insert `value` into the tree whose root pointer is *root. Duplicates are
 * allowed: equal keys descend to the left. Inserting into an empty tree
 * makes a black root; otherwise a red leaf is attached and fix_insert
 * rebalances starting from the leaf's parent.
 */
void rb_insert(rb_node_t **root, int value)
{
    rb_node_t *parent = NULL;
    rb_node_t **link = root;

    /* Walk down to the empty child slot where the key belongs. */
    while (*link != NULL)
    {
        parent = *link;
        link = (value <= parent->m_value) ? &parent->m_left : &parent->m_right;
    }

    *link = rb_node_new(parent, value);
    if (parent == NULL)
    {
        /* The new node is the root, which must be black. */
        (*root)->m_red = false;
        return;
    }
    fix_insert(root, parent);
}
/*
 * Remove one node whose key equals `value` from the tree at *root; does
 * nothing if no such key exists. A node with two children takes its
 * in-order successor's key, and the successor node is freed instead.
 * A removed node with one child is replaced by that child, recoloured
 * black. A childless non-root node is rebalanced with fix_remove before
 * being unlinked.
 */
void rb_remove(rb_node_t **root, int value)
{
    rb_node_t *victim = *root;
    while (victim != NULL && victim->m_value != value)
    {
        victim = (value < victim->m_value) ? victim->m_left : victim->m_right;
    }
    if (victim == NULL)
    {
        return;
    }

    if (victim->m_left != NULL && victim->m_right != NULL)
    {
        /* Leftmost node of the right subtree = in-order successor. */
        rb_node_t *succ = victim->m_right;
        while (succ->m_left != NULL)
        {
            succ = succ->m_left;
        }
        victim->m_value = succ->m_value;
        victim = succ;
    }

    /* victim now has at most one child. */
    rb_node_t *parent = victim->m_parent;
    rb_node_t *child = (victim->m_left != NULL) ? victim->m_left : victim->m_right;

    /*
     * Slot that currently points at victim. fix_remove's rotations never
     * move victim to a different side of its parent, so the slot stays valid.
     */
    rb_node_t **slot;
    if (parent == NULL)
    {
        slot = root;
    }
    else if (parent->m_left == victim)
    {
        slot = &parent->m_left;
    }
    else
    {
        slot = &parent->m_right;
    }

    if (child != NULL)
    {
        child->m_parent = parent;
        child->m_red = false;
    }
    else if (parent != NULL)
    {
        fix_remove(root, victim);
    }
    *slot = child;
    free(victim);
}
/*
 * Height of the subtree rooted at `p`, counted in nodes (0 for an empty
 * tree). Recursive, so stack usage grows with the tree height.
 */
int rb_depth(rb_node_t *p)
{
    if (!p)
    {
        return 0;
    }
    const int left = rb_depth(p->m_left);
    const int right = rb_depth(p->m_right);
    return (left > right ? left : right) + 1;
}
int main()
{
rb_node_t *root = NULL;
for (int i = 0; i < 1024 * 1024; i++)
{
rb_insert(&root, i);
}
printf("rb_depth = %d\n", rb_depth(root));
for (int i = 0; i < 1024 * 1024 - 1024; i++)
{
rb_remove(&root, i);
}
printf("rb_depth = %d\n", rb_depth(root));
for (int i = 0; i < 1024 * 1024 - 1024; i++)
{
rb_insert(&root, i);
}
printf("rb_depth = %d\n", rb_depth(root));
for (int i = 0; i < 1024 * 1024 - 1024; i++)
{
rb_remove(&root, i);
}
printf("rb_depth = %d\n", rb_depth(root));
}