c++实现kd树

  1 #ifndef _KD_TREE_H_
  2 #define _KD_TREE_H_
  3 
  4 #include <memory>
  5 #include <vector>
  6 #include <algorithm>
  7 #include <iostream>
  8 #include <functional>
  9 #include <iomanip>
 10 #include <stack>
 11 #include <array>
 12 #include <cfloat>
 13 #include <cmath>
 14 
 15 namespace zstd
 16 {
 17     struct threeD_node
 18     {
 19         double value[3];//value[0] = x, value[1] = y, value[2] = z
 20         threeD_node()
 21         {
 22             value[0] = 0.0;
 23             value[1] = 0.0;
 24             value[2] = 0.0;
 25         }
 26         threeD_node(double x, double y, double z)
 27         {
 28             value[0] = x;
 29             value[1] = y;
 30             value[2] = z;
 31         }
 32     };
 33     struct sort_for_threeD_node
 34     {
 35         int dimension;
 36         sort_for_threeD_node(int d) :dimension(d){}
 37         bool operator ()(const threeD_node& lhs, const threeD_node& rhs)
 38         {
 39             if (dimension == 0)
 40                 return lhs.value[0] < rhs.value[0];
 41             else if (dimension == 1)
 42                 return lhs.value[1] < rhs.value[1];
 43             else if (dimension == 2)
 44                 return lhs.value[2] < rhs.value[2];
 45             else
 46                 std::cerr << "error in sort_for_threeD_node"<< std::endl;
 47             return false;
 48         }
 49     };
 50 
 51     struct kd_node
 52     {
 53         double value[3];//value[0] = x, value[1] = y, value[2] = z
 54         int kv;//0=x, 1=y, 2=z
 55         bool is_leaf;
 56         kd_node  *left, *right;
 57         kd_node()
 58         {
 59             value[0] = value[1] = value[2] = 0.0;
 60             kv = -1;
 61             is_leaf = false;
 62             left = nullptr;
 63             right = nullptr;
 64         }
 65         kd_node(const kd_node& node)
 66         {
 67             value[0] = node.value[0];
 68             value[1] = node.value[1];
 69             value[2] = node.value[2];
 70             kv = node.kv;
 71             is_leaf = node.is_leaf;
 72             left = node.left;
 73             right = node.right;
 74         }
 75         kd_node& operator = (const kd_node& node)
 76         {
 77             value[0] = node.value[0];
 78             value[1] = node.value[1];
 79             value[2] = node.value[2];
 80             kv = node.kv;
 81             is_leaf = node.is_leaf;
 82             left = node.left;
 83             right = node.right;
 84 
 85             return *this;
 86         }
 87     };
 88     class kd_tree
 89     {
 90     private:
 91         std::shared_ptr<kd_node> root;
 92         std::vector<threeD_node>& vec_ref;
 93         const int k = 3;
 94         const int cspace = 4;
 95     private:
 96         int get_dimension(int n) const
 97         {
 98             return n % k;
 99         }
100         void sort_by_dimension(std::vector<threeD_node>& v, int dimension, int l, int r);
101         kd_node* build_tree(int left, int right, kd_node* sp_node, int dimension);
102         void _print_tree(kd_node* sp, bool left, int space);
103 
104         double distance(const kd_node& lhs, const threeD_node& rhs);
105     public:
106         explicit kd_tree(std::vector<threeD_node>&);
107         kd_tree(const kd_tree&) = delete;
108         kd_tree operator = (const kd_tree&) = delete;
109         ~kd_tree(){};
110 
111         void print_tree();
112         std::vector<threeD_node> find_k_nearest(int k, const threeD_node& D);
113     };
114     void kd_tree::sort_by_dimension(std::vector<threeD_node>& v, int dimension, int l, int r)
115     {
116         sort_for_threeD_node s(dimension);
117         std::sort(v.begin()+l, v.begin()+r, s);
118     }
119     kd_tree::kd_tree(std::vector<threeD_node>& v) :vec_ref(v)
120     {
121         if (vec_ref.empty())
122             root = nullptr;
123         else
124         {
125             root = std::make_shared<kd_node>();
126             int dimension = 0;
127             sort_by_dimension(vec_ref, dimension, 0, vec_ref.size());
128             int mid = vec_ref.size() / 2;
129             root->value[0] = vec_ref[mid].value[0];
130             root->value[1] = vec_ref[mid].value[1];
131             root->value[2] = vec_ref[mid].value[2];
132             root->kv = dimension;
133             if (vec_ref.size() == 1)//root is leaf
134             {
135                 root->left = nullptr;
136                 root->right = nullptr;
137                 root->is_leaf = true;
138             }
139             else
140             {
141                 root->is_leaf = false;
142                 root->left = build_tree(0, mid - 1, root->left, get_dimension(dimension + 1));
143                 root->right = build_tree(mid + 1, vec_ref.size() - 1, root->right, get_dimension(dimension + 1));
144             }
145         }
146     }
147     kd_node* kd_tree::build_tree(int left, int right, kd_node* sp_node, int dimension)
148     {
149         dimension = get_dimension(dimension);
150         sort_by_dimension(vec_ref, dimension, left, right + 1);
151 
152         if(left == right)//leaf
153         {
154             sp_node = new kd_node();
155             sp_node->value[0] = vec_ref[left].value[0];
156             sp_node->value[1] = vec_ref[left].value[1];
157             sp_node->value[2] = vec_ref[left].value[2];
158             sp_node->kv = dimension;
159             sp_node->is_leaf = true;
160             sp_node->left = nullptr;
161             sp_node->right = nullptr;
162 
163             return sp_node;
164         }
165         else if (left < right)
166         {
167             int mid = left + (right - left) / 2;
168             sp_node = new kd_node();
169             sp_node->value[0] = vec_ref[mid].value[0];
170             sp_node->value[1] = vec_ref[mid].value[1];
171             sp_node->value[2] = vec_ref[mid].value[2];
172             sp_node->kv = dimension;
173             sp_node->is_leaf = false;
174             sp_node->left = nullptr;
175             sp_node->right = nullptr;
176 
177             sp_node->left = build_tree(left, mid - 1, sp_node->left, get_dimension(dimension + 1));
178             sp_node->right = build_tree(mid + 1, right, sp_node->right, get_dimension(dimension + 1));
179 
180             return sp_node;
181         }
182         return nullptr;
183     }
184     void kd_tree::_print_tree(kd_node* sp, bool left, int space)
185     {
186         if (sp != nullptr)
187         {
188             _print_tree(sp->right, false, space + cspace);
189             std::cout << std::setw(space);
190             std::cout << "(" <<
191                 sp->value[0] << ", " <<
192                 sp->value[1] << ", " <<
193                 sp->value[2] << ")";
194             if (left)
195                 std::cout << "left";
196             else
197                 std::cout << "right";
198             if (sp->is_leaf)
199                 std::cout << "------leaf";
200             std::cout << std::endl;
201             _print_tree(sp->left, true, space + cspace);
202         }
203         else
204             std::cout << std::endl;
205     }
206     void kd_tree::print_tree()
207     {
208         std::cout << "kd_tree : " << std::endl;
209         if (root != nullptr)
210         {
211             int space = 0;
212             _print_tree(root->right, false, space + cspace);
213             std::cout << "(" << 
214                 root->value[0] << ", " << 
215                 root->value[1] << ", " << 
216                 root->value[2] << ")root" << std::endl;
217             _print_tree(root->left, true, space + cspace);
218         }
219     }
220     double kd_tree::distance(const kd_node& lhs, const threeD_node& rhs)
221     {
222         double v0 = lhs.value[0] - rhs.value[0];
223         double v1 = lhs.value[1] - rhs.value[1];
224         double v2 = lhs.value[2] - rhs.value[2];
225         return sqrt(v0 * v0 + v1 * v1 + v2 * v2);
226     }
227     std::vector<threeD_node> kd_tree::find_k_nearest(int ks, const threeD_node& D)
228     {
229         std::vector<threeD_node> res;
230         const kd_node *ptr_kd_node;
231         if (static_cast<std::size_t>(ks) > vec_ref.size())
232             return res;
233         std::stack<kd_node> s;
234         struct pair
235         {
236             double distance;
237             kd_node node;
238             pair() :distance(DBL_MAX), node(){ }
239             bool operator < (const pair& rhs)
240             {
241                 return distance < rhs.distance;
242             }
243         };
244         std::unique_ptr<pair[]> ptr_pair(new pair[ks]);
245         //pair *ptr_pair = new pair[ks]();
246         if (!ptr_pair)
247             exit(-1);
248 
249         if (!root)//the tree is empty
250             return std::vector<threeD_node>();
251         else
252         {
253             if (D.value[root->kv] < root->value[root->kv])
254             {
255                 s.push(*root);
256                 ptr_kd_node = root->left;
257             }
258             else
259             {
260                 s.push(*root);
261                 ptr_kd_node = root->right;
262             }
263             while (ptr_kd_node != nullptr)
264             {
265                 if (D.value[ptr_kd_node->kv] < ptr_kd_node->value[ptr_kd_node->kv])
266                 {
267                     s.push(*ptr_kd_node);
268                     ptr_kd_node = ptr_kd_node->left;
269                 }
270                 else
271                 {
272                     s.push(*ptr_kd_node);
273                     ptr_kd_node = ptr_kd_node->right;
274                 }
275             }
276             
277             while (!s.empty())
278             {
279                 kd_node popped_kd_node;//±£´æ×îеĴÓÕ»ÖÐpop³öµÄkd_node
280                 popped_kd_node = s.top();
281                 s.pop();
282                 double dist = distance(popped_kd_node, D);
283                 std::sort(&ptr_pair[0], &ptr_pair[ks]);
284                 if (dist < ptr_pair[ks-1].distance)
285                 {
286                     ptr_pair[ks-1].distance = dist;
287                     ptr_pair[ks-1].node = popped_kd_node;
288                 }
289 
290                 if (abs(D.value[popped_kd_node.kv] - popped_kd_node.value[popped_kd_node.kv])
291                         >= dist)//Ô²²»ºÍpopped_kd_nodeµÄÁíÒ»°ëÇøÓòÏཻ
292                     continue;
293                 else//Ô²ºÍpopped_kd_nodeµÄÁíÒ»°ëÇøÓòÏཻ
294                 {
295                     if (D.value[popped_kd_node.kv] < popped_kd_node.value[popped_kd_node.kv])//right
296                     {
297                         kd_node *ptr = popped_kd_node.right;
298                         while (ptr != nullptr)
299                         {    
300                             s.push(*ptr);
301                             if (D.value[ptr->kv] < ptr->value[ptr->kv])
302                                 ptr = ptr->left;
303                             else
304                                 ptr = ptr->right;
305                         }
306                     }
307                     else//left
308                     {
309                         kd_node *ptr = popped_kd_node.left;
310                         while (ptr != nullptr)
311                         {
312                             s.push(*ptr);
313                             if (D.value[ptr->kv] < ptr->value[ptr->kv])
314                                 ptr = ptr->left;
315                             else
316                                 ptr = ptr->right;
317                         }
318                     }
319                 }
320             }//end of while
321             for(int i = 0; i != ks; ++i)
322                 res.push_back(threeD_node(ptr_pair[i].node.value[0], 
323                             ptr_pair[i].node.value[1], ptr_pair[i].node.value[2]));
324         }//end of else
325         //delete ptr_pair;
326         return res;
327     }
328 
329 }//end of namespace zstd
330 
331 #endif
 1 #include <string>
 2 #include <iostream>
 3 #include <new>
 4 #include <fstream>
 5 #include <vector>
 6 #include <algorithm>
 7 #include <ctime>
 8 
 9 #include "trie_tree.h"
10 #include "kd_tree.h"
11 
12 int main()
13 {
14     std::vector<zstd::threeD_node> v, res;
15     v.push_back(zstd::threeD_node(2, 3, 1));//14
16     v.push_back(zstd::threeD_node(5, 4, 7));//90
17     v.push_back(zstd::threeD_node(9, 6, 9));//198
18     v.push_back(zstd::threeD_node(4, 7, 2));//69
19     v.push_back(zstd::threeD_node(8, 1, 5));//90
20     v.push_back(zstd::threeD_node(7, 2, 0));//53
21     v.push_back(zstd::threeD_node(8, 8, 8));//192
22     v.push_back(zstd::threeD_node(1, 2, 3));//14
23     v.push_back(zstd::threeD_node(5, 2, 1));//30
24     v.push_back(zstd::threeD_node(12, 23, 0));//673
25     v.push_back(zstd::threeD_node(10, 0, 2));//104
26     std::cout << "size: " << v.size() << std::endl;
27     zstd::kd_tree tree(v);
28     tree.print_tree();
29     res = tree.find_k_nearest(11, zstd::threeD_node(0, 0, 0));
30     std::cout << "-------" << std::endl;
31     std::cout << "离点(0,0,0)最近的点依次是:" << std::endl;
32     for (auto i : res)
33     {
34         std::cout << "(" << i.value[0] << ", " << i.value[1] << ", " << i.value[2] << ")" << std::endl;
35     }
36     system("pause");
37     return 0;
38 }

posted @ 2013-12-25 18:59  老司机  阅读(1285)  评论(0编辑  收藏  举报