1 ////二叉查找树,为了实现方便,给每个节点添加了一个指向父节点的指针
  2 #include<iostream>
  3 #include<vector>
  4 #include<ctime>
  5 #include<cstdlib>
  6 
  7 using namespace std;
  8 
  9 template<class T>
 10 class BinarySearchTree
 11 {
 12     private:
 13         struct Node
 14         {
 15             T data;
 16             int deep;
 17             Node *left;
 18             Node *right;
 19             Node *prev;
 20             Node(T val,int deep)
 21             {
 22                 data = val;
 23                 deep = 0;
 24                 left = NULL;
 25                 right = NULL;
 26                 prev = NULL;
 27             }
 28 
 29             private:
 30             Node()
 31             {
 32             }
 33         };
 34         Node *root;
 35         int size;
 36 
 37     public:
 38         BinarySearchTree()
 39         {
 40             root = NULL;
 41             size = 0;
 42         }
 43         ~BinarySearchTree()
 44         {
 45             clear(root);
 46             root = NULL;
 47             size = 0;
 48         }
 49         T min(Node *node) const
 50         {
 51             if(node->left == NULL)
 52                 return node->data;
 53             else
 54                 return min(node->left);
 55         }
 56         T max(Node *node) const
 57         {
 58             if(node->right == NULL)
 59                 return node->data;
 60             else
 61                 return max(node->right);
 62         }
 63 
 64         Node *insert(Node *& node,T val)
 65         {
 66             if(size == 0 && node == NULL)
 67             {
 68                 root = new Node(val,0);
 69                 root->prev = NULL;
 70                 size++;
 71                 return root;
 72             }
 73             if(size != 0 && node == NULL)
 74             {
 75                 cout<<"ERROR\n";
 76                 return NULL;
 77             }
 78             if(val > node->data)
 79             {
 80                 if(node->right != NULL)
 81                     return insert(node->right,val);
 82                 else
 83                 {
 84                     Node *tmp = new Node(val,node->deep+1);
 85                     tmp->prev = node;
 86                     node->right = tmp;
 87                     size++;
 88                     return tmp;
 89                 }    
 90             }
 91             else if(val < node->data)
 92             {
 93                 if(node->left != NULL)
 94                     return insert(node->left,val);
 95                 else
 96                 {
 97                     Node *tmp = new Node(val,node->deep+1);
 98                     tmp->prev = node;
 99                     node->left = tmp;
100                     size ++;
101                     return tmp;
102                 }
103             }
104             else if(val == node->data)
105             {
106             }
107         }
108 
109         bool contain(Node *node,T val) const
110         {
111             if(node == NULL)
112                 return false;
113 
114             if(val > node->data)
115                 return contain(node->right,val);
116             else if(val < node->data)
117                 return contain(node->left,val);
118             else
119                 return true;
120         }
121         void removeNode(Node *node)
122         {
123             if(node->left == NULL && node->right == NULL)
124             {
125                 if(node->prev->left == node)
126                     node->prev->left = NULL;
127                 else
128                     node->prev->right = NULL;
129 
130                 delete node;
131                 size--;
132             }
133             else if(node->left == NULL)
134             {
135                 node->right->prev = node->prev;
136                 if(node->prev->left == node)
137                     node->prev->left = node->right;
138                 else
139                     node->prev->right = node->right;
140 
141                 decDeep(node->right);
142                 delete node;
143                 size--;
144             }
145             else if(node->right == NULL)
146             {
147                 node->left->prev = node->prev;
148                 if(node->prev->left == node)
149                     node->prev->left = node->left;
150                 else
151                     node->prev->right = node->left;
152 
153                 decDeep(node->left);
154                 delete node;
155                 size--;
156             }
157             else
158             {
159                 Node *p = node->right;
160                 while(p->left != NULL)
161                 {
162                     p=p->left;
163                 }
164                 node->data = p->data;
165                 if(p->right != NULL)
166                 {
167                     p->prev->left = p->right;
168                     p->right->prev = p->prev;
169                     decDeep(p->right);
170                     delete p;
171                     size--;
172                 }
173                 else
174                 {
175                     p->prev->left = NULL;
176                     delete p;
177                     size--;
178                 }
179             }
180         }
181         void decDeep(Node *node)
182         {
183             node->deep--;
184             if(node->left != NULL)
185                 decDeep(node->left);
186             if(node->right != NULL)
187                 decDeep(node->right);
188         }
189         void remove(T val)
190         {
191             Node * p=root;
192             while(1)
193             {
194                 if(val > p->data)
195                     p = p->right;
196                 else if(val < p->data)
197                     p = p->left;
198                 else if(val == p->data)
199                 {
200                     
201                     removeNode(p);
202                     return;
203                 }
204             }
205         }
206         void clear(Node*node)
207         {
208             if(node->left != NULL)
209                 clear(node->left);
210             if(node->right != NULL)
211                 clear(node->right);
212 
213             delete node;
214             node = NULL;
215         }
216         void print(Node *node)
217         {
218             if(node == NULL)
219                 return;
220             cout<<node->data<< " ";
221             if(node->left != NULL)
222                 print(node->left);
223             if(node->right != NULL)
224                 print(node->right);
225         }
226         void insert(T val)
227         {
228             insert(root,val);
229         }
230         void print()
231         {
232             print(root);
233             cout<<"\n";
234         }
235 };
236 
237 int main()
238 {
239     BinarySearchTree<int> tree;
240     tree.insert(10);
241     tree.insert(1);
242     tree.insert(11);
243     tree.insert(9);
244     tree.insert(8);
245     tree.print();
246     cout<<"\n\n";
247     tree.remove(9);
248     tree.print();
249 
250     return 0;
251 }