1 #ifndef _KD_TREE_H_
2 #define _KD_TREE_H_
3
4 #include <memory>
5 #include <vector>
6 #include <algorithm>
7 #include <iostream>
8 #include <functional>
9 #include <iomanip>
10 #include <stack>
11 #include <array>
12 #include <cfloat>
13 #include <cmath>
14
15 namespace zstd
16 {
/// A point in 3-D space: value[0] = x, value[1] = y, value[2] = z.
struct threeD_node
{
    double value[3];

    /// Constructs the origin (0, 0, 0).
    threeD_node() : value{0.0, 0.0, 0.0} {}

    /// Constructs the point (x, y, z).
    threeD_node(double x, double y, double z) : value{x, y, z} {}
};
33 struct sort_for_threeD_node
34 {
35 int dimension;
36 sort_for_threeD_node(int d) :dimension(d){}
37 bool operator ()(const threeD_node& lhs, const threeD_node& rhs)
38 {
39 if (dimension == 0)
40 return lhs.value[0] < rhs.value[0];
41 else if (dimension == 1)
42 return lhs.value[1] < rhs.value[1];
43 else if (dimension == 2)
44 return lhs.value[2] < rhs.value[2];
45 else
46 std::cerr << "error in sort_for_threeD_node"<< std::endl;
47 return false;
48 }
49 };
50
/// One node of the k-d tree: a copy of a point, the axis this node splits
/// on, and links to its two children.
///
/// kd_node does not own its children (it has no destructor), so copying a
/// kd_node is a shallow, member-wise copy that shares `left`/`right` with the
/// original. The compiler-generated copy operations do exactly that, so they
/// are not hand-written (Rule of Zero).
struct kd_node
{
    double value[3] = {0.0, 0.0, 0.0}; // value[0] = x, value[1] = y, value[2] = z
    int kv = -1;                       // splitting axis: 0 = x, 1 = y, 2 = z; -1 = not set
    bool is_leaf = false;              // set by the tree builder for single-point ranges
    kd_node* left = nullptr;           // left child, or nullptr
    kd_node* right = nullptr;          // right child, or nullptr

    kd_node() = default;
};
88 class kd_tree
89 {
90 private:
91 std::shared_ptr<kd_node> root;
92 std::vector<threeD_node>& vec_ref;
93 const int k = 3;
94 const int cspace = 4;
95 private:
96 int get_dimension(int n) const
97 {
98 return n % k;
99 }
100 void sort_by_dimension(std::vector<threeD_node>& v, int dimension, int l, int r);
101 kd_node* build_tree(int left, int right, kd_node* sp_node, int dimension);
102 void _print_tree(kd_node* sp, bool left, int space);
103
104 double distance(const kd_node& lhs, const threeD_node& rhs);
105 public:
106 explicit kd_tree(std::vector<threeD_node>&);
107 kd_tree(const kd_tree&) = delete;
108 kd_tree operator = (const kd_tree&) = delete;
109 ~kd_tree(){};
110
111 void print_tree();
112 std::vector<threeD_node> find_k_nearest(int k, const threeD_node& D);
113 };
114 void kd_tree::sort_by_dimension(std::vector<threeD_node>& v, int dimension, int l, int r)
115 {
116 sort_for_threeD_node s(dimension);
117 std::sort(v.begin()+l, v.begin()+r, s);
118 }
119 kd_tree::kd_tree(std::vector<threeD_node>& v) :vec_ref(v)
120 {
121 if (vec_ref.empty())
122 root = nullptr;
123 else
124 {
125 root = std::make_shared<kd_node>();
126 int dimension = 0;
127 sort_by_dimension(vec_ref, dimension, 0, vec_ref.size());
128 int mid = vec_ref.size() / 2;
129 root->value[0] = vec_ref[mid].value[0];
130 root->value[1] = vec_ref[mid].value[1];
131 root->value[2] = vec_ref[mid].value[2];
132 root->kv = dimension;
133 if (vec_ref.size() == 1)//root is leaf
134 {
135 root->left = nullptr;
136 root->right = nullptr;
137 root->is_leaf = true;
138 }
139 else
140 {
141 root->is_leaf = false;
142 root->left = build_tree(0, mid - 1, root->left, get_dimension(dimension + 1));
143 root->right = build_tree(mid + 1, vec_ref.size() - 1, root->right, get_dimension(dimension + 1));
144 }
145 }
146 }
147 kd_node* kd_tree::build_tree(int left, int right, kd_node* sp_node, int dimension)
148 {
149 dimension = get_dimension(dimension);
150 sort_by_dimension(vec_ref, dimension, left, right + 1);
151
152 if(left == right)//leaf
153 {
154 sp_node = new kd_node();
155 sp_node->value[0] = vec_ref[left].value[0];
156 sp_node->value[1] = vec_ref[left].value[1];
157 sp_node->value[2] = vec_ref[left].value[2];
158 sp_node->kv = dimension;
159 sp_node->is_leaf = true;
160 sp_node->left = nullptr;
161 sp_node->right = nullptr;
162
163 return sp_node;
164 }
165 else if (left < right)
166 {
167 int mid = left + (right - left) / 2;
168 sp_node = new kd_node();
169 sp_node->value[0] = vec_ref[mid].value[0];
170 sp_node->value[1] = vec_ref[mid].value[1];
171 sp_node->value[2] = vec_ref[mid].value[2];
172 sp_node->kv = dimension;
173 sp_node->is_leaf = false;
174 sp_node->left = nullptr;
175 sp_node->right = nullptr;
176
177 sp_node->left = build_tree(left, mid - 1, sp_node->left, get_dimension(dimension + 1));
178 sp_node->right = build_tree(mid + 1, right, sp_node->right, get_dimension(dimension + 1));
179
180 return sp_node;
181 }
182 return nullptr;
183 }
184 void kd_tree::_print_tree(kd_node* sp, bool left, int space)
185 {
186 if (sp != nullptr)
187 {
188 _print_tree(sp->right, false, space + cspace);
189 std::cout << std::setw(space);
190 std::cout << "(" <<
191 sp->value[0] << ", " <<
192 sp->value[1] << ", " <<
193 sp->value[2] << ")";
194 if (left)
195 std::cout << "left";
196 else
197 std::cout << "right";
198 if (sp->is_leaf)
199 std::cout << "------leaf";
200 std::cout << std::endl;
201 _print_tree(sp->left, true, space + cspace);
202 }
203 else
204 std::cout << std::endl;
205 }
206 void kd_tree::print_tree()
207 {
208 std::cout << "kd_tree : " << std::endl;
209 if (root != nullptr)
210 {
211 int space = 0;
212 _print_tree(root->right, false, space + cspace);
213 std::cout << "(" <<
214 root->value[0] << ", " <<
215 root->value[1] << ", " <<
216 root->value[2] << ")root" << std::endl;
217 _print_tree(root->left, true, space + cspace);
218 }
219 }
220 double kd_tree::distance(const kd_node& lhs, const threeD_node& rhs)
221 {
222 double v0 = lhs.value[0] - rhs.value[0];
223 double v1 = lhs.value[1] - rhs.value[1];
224 double v2 = lhs.value[2] - rhs.value[2];
225 return sqrt(v0 * v0 + v1 * v1 + v2 * v2);
226 }
227 std::vector<threeD_node> kd_tree::find_k_nearest(int ks, const threeD_node& D)
228 {
229 std::vector<threeD_node> res;
230 const kd_node *ptr_kd_node;
231 if (static_cast<std::size_t>(ks) > vec_ref.size())
232 return res;
233 std::stack<kd_node> s;
234 struct pair
235 {
236 double distance;
237 kd_node node;
238 pair() :distance(DBL_MAX), node(){ }
239 bool operator < (const pair& rhs)
240 {
241 return distance < rhs.distance;
242 }
243 };
244 std::unique_ptr<pair[]> ptr_pair(new pair[ks]);
245 //pair *ptr_pair = new pair[ks]();
246 if (!ptr_pair)
247 exit(-1);
248
249 if (!root)//the tree is empty
250 return std::vector<threeD_node>();
251 else
252 {
253 if (D.value[root->kv] < root->value[root->kv])
254 {
255 s.push(*root);
256 ptr_kd_node = root->left;
257 }
258 else
259 {
260 s.push(*root);
261 ptr_kd_node = root->right;
262 }
263 while (ptr_kd_node != nullptr)
264 {
265 if (D.value[ptr_kd_node->kv] < ptr_kd_node->value[ptr_kd_node->kv])
266 {
267 s.push(*ptr_kd_node);
268 ptr_kd_node = ptr_kd_node->left;
269 }
270 else
271 {
272 s.push(*ptr_kd_node);
273 ptr_kd_node = ptr_kd_node->right;
274 }
275 }
276
277 while (!s.empty())
278 {
279 kd_node popped_kd_node;//±£´æ×îеĴÓÕ»ÖÐpop³öµÄkd_node
280 popped_kd_node = s.top();
281 s.pop();
282 double dist = distance(popped_kd_node, D);
283 std::sort(&ptr_pair[0], &ptr_pair[ks]);
284 if (dist < ptr_pair[ks-1].distance)
285 {
286 ptr_pair[ks-1].distance = dist;
287 ptr_pair[ks-1].node = popped_kd_node;
288 }
289
290 if (abs(D.value[popped_kd_node.kv] - popped_kd_node.value[popped_kd_node.kv])
291 >= dist)//Ô²²»ºÍpopped_kd_nodeµÄÁíÒ»°ëÇøÓòÏཻ
292 continue;
293 else//Ô²ºÍpopped_kd_nodeµÄÁíÒ»°ëÇøÓòÏཻ
294 {
295 if (D.value[popped_kd_node.kv] < popped_kd_node.value[popped_kd_node.kv])//right
296 {
297 kd_node *ptr = popped_kd_node.right;
298 while (ptr != nullptr)
299 {
300 s.push(*ptr);
301 if (D.value[ptr->kv] < ptr->value[ptr->kv])
302 ptr = ptr->left;
303 else
304 ptr = ptr->right;
305 }
306 }
307 else//left
308 {
309 kd_node *ptr = popped_kd_node.left;
310 while (ptr != nullptr)
311 {
312 s.push(*ptr);
313 if (D.value[ptr->kv] < ptr->value[ptr->kv])
314 ptr = ptr->left;
315 else
316 ptr = ptr->right;
317 }
318 }
319 }
320 }//end of while
321 for(int i = 0; i != ks; ++i)
322 res.push_back(threeD_node(ptr_pair[i].node.value[0],
323 ptr_pair[i].node.value[1], ptr_pair[i].node.value[2]));
324 }//end of else
325 //delete ptr_pair;
326 return res;
327 }
328
329 }//end of namespace zstd
330
331 #endif