[Machine Learning] A kNN Implementation (kd-tree)

For the detailed description of the algorithm, see Chapter 3 of 《统计学习方法》 (Statistical Learning Methods).

//
//  main.cpp
//  kNN
//
//  Created by feng on 15/10/24.
//  Copyright © 2015 ttcn. All rights reserved.
//

#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
using namespace std;

template<typename T>
struct KdTree {
    // Constructor: an empty, unlinked node
    KdTree() : parent(nullptr), leftChild(nullptr), rightChild(nullptr) {}

    // Whether this node holds no data
    bool isEmpty() { return root.empty(); }

    // Whether this node is a leaf (holds data and has no children)
    bool isLeaf() { return !root.empty() && !leftChild && !rightChild; }

    // Whether this node is the root of the whole tree
    bool isRoot() { return !isEmpty() && !parent; }

    // Whether this node is its parent's left child
    bool isLeft() { return parent && parent->leftChild && parent->leftChild->root == root; }

    // Whether this node is its parent's right child
    bool isRight() { return parent && parent->rightChild && parent->rightChild->root == root; }

    // Data point stored at this node
    vector<T> root;

    // Parent node
    KdTree<T> *parent;

    // Left child
    KdTree<T> *leftChild;

    // Right child
    KdTree<T> *rightChild;
};

/**
 *  Matrix transpose
 *
 *  @param matrix the original matrix
 *
 *  @return the transpose of matrix
 */
template<typename T>
vector<vector<T>> transpose(const vector<vector<T>> &matrix) {
    size_t rows = matrix.size();
    size_t cols = matrix[0].size();
    vector<vector<T>> trans(cols, vector<T>(rows, 0));
    for (size_t i = 0; i < cols; ++i) {
        for (size_t j = 0; j < rows; ++j) {
            trans[i][j] = matrix[j][i];
        }
    }

    return trans;
}

/**
 *  Find the median value of a vector
 *
 *  @param vec the values (passed by value so the caller's data is not reordered)
 *
 *  @return the median element of vec
 */
template<typename T>
T findMiddleValue(vector<T> vec) {
    sort(vec.begin(), vec.end());
    size_t pos = vec.size() / 2;
    return vec[pos];
}

/**
 *  Recursively build the KdTree
 *
 *  @param tree  root of the (sub)tree to fill in
 *  @param data  data matrix, one sample per row
 *  @param depth depth of the current node
 *
 *  @return void
 */
template<typename T>
void buildKdTree(KdTree<T> *tree, vector<vector<T>> &data, size_t depth) {
    // Number of input samples
    size_t samplesNum = data.size();

    if (samplesNum == 0) {
        return;
    }

    if (samplesNum == 1) {
        tree->root = data[0];
        return;
    }

    // Dimension (number of attributes) of each sample
    size_t k = data[0].size();
    vector<vector<T>> transData = transpose(data);

    // Choose the splitting attribute by depth and split at its median value
    size_t splitAttributeIndex = depth % k;
    vector<T> splitAttributes = transData[splitAttributeIndex];
    T splitValue = findMiddleValue(splitAttributes);

    vector<vector<T>> leftSubset;
    vector<vector<T>> rightSubset;

    for (size_t i = 0; i < samplesNum; ++i) {
        // The first sample equal to the median becomes this node's point;
        // the rest go to the left or right subset.
        if (splitAttributes[i] == splitValue && tree->isEmpty()) {
            tree->root = data[i];
        } else if (splitAttributes[i] < splitValue) {
            leftSubset.push_back(data[i]);
        } else {
            rightSubset.push_back(data[i]);
        }
    }

    // Children are always allocated here; a child whose subset is empty
    // keeps an empty root (see isEmpty()).
    tree->leftChild = new KdTree<T>;
    tree->leftChild->parent = tree;
    tree->rightChild = new KdTree<T>;
    tree->rightChild->parent = tree;
    buildKdTree(tree->leftChild, leftSubset, depth + 1);
    buildKdTree(tree->rightChild, rightSubset, depth + 1);
}

/**
 *  Recursively print the KdTree
 *
 *  @param tree  the KdTree
 *  @param depth depth of the current node
 *
 *  @return void
 */
template<typename T>
void printKdTree(const KdTree<T> *tree, size_t depth) {
    // Indent according to depth, then print this node's point
    for (size_t i = 0; i < depth; ++i) {
        cout << "\t";
    }

    for (size_t i = 0; i < tree->root.size(); ++i) {
        cout << tree->root[i] << " ";
    }
    cout << endl;

    if (tree->leftChild == nullptr && tree->rightChild == nullptr) {
        return;
    } else {
        if (tree->leftChild) {
            for (size_t i = 0; i < depth + 1; ++i) {
                cout << "\t";
            }
            cout << "left : ";
            printKdTree(tree->leftChild, depth + 1);
        }

        cout << endl;

        if (tree->rightChild) {
            for (size_t i = 0; i < depth + 1; ++i) {
                cout << "\t";
            }
            cout << "right : ";
            printKdTree(tree->rightChild, depth + 1);
        }
        cout << endl;
    }
}

/**
 *  Euclidean distance between two points
 *
 *  @param p1 point 1
 *  @param p2 point 2
 *
 *  @return the Euclidean distance between p1 and p2
 */
template<typename T>
T calDistance(const vector<T> &p1, const vector<T> &p2) {
    T res = 0;
    for (size_t i = 0; i < p1.size(); ++i) {
        res += pow(p1[i] - p2[i], 2);
    }

    // Take the square root so that this value can be compared directly with
    // a per-coordinate distance to a splitting hyperplane during the search.
    return static_cast<T>(sqrt(res));
}

/**
 *  Search for the nearest neighbor of a query point
 *
 *  @param tree the KdTree
 *  @param goal the query point to classify
 *
 *  @return the nearest-neighbor point found in the tree
 */
template <typename T>
vector<T> searchNearestNeighbor(KdTree<T> *tree, const vector<T> &goal) {
    // Number of attributes per point
    size_t k = tree->root.size();
    // Counter used to derive the splitting dimension
    size_t d = 0;
    KdTree<T> *currentTree = tree;
    vector<T> currentNearest = currentTree->root;

    // Descend to the leaf whose region contains the goal point
    while (!currentTree->isLeaf()) {
        size_t index = d % k;
        // Go left if the right subtree is empty, or if the goal lies on the
        // left side of this node's split (and the left subtree is non-empty)
        if (currentTree->rightChild->isEmpty() ||
            (!currentTree->leftChild->isEmpty() && goal[index] < currentTree->root[index])) {
            currentTree = currentTree->leftChild;
        } else {
            currentTree = currentTree->rightChild;
        }

        ++d;
    }
    currentNearest = currentTree->root;
    T currentDistance = calDistance(goal, currentTree->root);

    // A single-node tree: the root is the nearest neighbor
    if (currentTree->isRoot()) {
        return currentNearest;
    }

    // Start backtracking from the sibling region of the leaf we reached
    KdTree<T> *searchDistrict;
    if (currentTree->isLeft()) {
        if (!(currentTree->parent->rightChild)) {
            searchDistrict = currentTree;
        } else {
            searchDistrict = currentTree->parent->rightChild;
        }
    } else {
        searchDistrict = currentTree->parent->leftChild;
    }

    // Backtrack towards the root, checking whether the candidate region
    // can contain a point closer than the current best
    while (searchDistrict->parent) {
        // Per-coordinate distance from the goal to the parent's point along
        // the dimension derived from the depth counter d
        T districtDistance = abs(goal[(d + 1) % k] - searchDistrict->parent->root[(d + 1) % k]);

        if (districtDistance < currentDistance) {
            T parentDistance = calDistance(goal, searchDistrict->parent->root);

            if (parentDistance < currentDistance) {
                currentDistance = parentDistance;
                currentTree = searchDistrict->parent;
                currentNearest = currentTree->root;
            }

            if (!searchDistrict->isEmpty()) {
                T rootDistance = calDistance(goal, searchDistrict->root);
                if (rootDistance < currentDistance) {
                    currentDistance = rootDistance;
                    currentTree = searchDistrict;
                    currentNearest = currentTree->root;
                }
            }

            if (searchDistrict->leftChild && !searchDistrict->leftChild->isEmpty()) {
                T leftDistance = calDistance(goal, searchDistrict->leftChild->root);
                if (leftDistance < currentDistance) {
                    currentDistance = leftDistance;
                    currentTree = searchDistrict->leftChild;
                    currentNearest = currentTree->root;
                }
            }

            if (searchDistrict->rightChild && !searchDistrict->rightChild->isEmpty()) {
                T rightDistance = calDistance(goal, searchDistrict->rightChild->root);
                if (rightDistance < currentDistance) {
                    currentDistance = rightDistance;
                    currentTree = searchDistrict->rightChild;
                    currentNearest = currentTree->root;
                }
            }
        }

        // Move up one level: continue with the sibling of the parent,
        // or stop once the root has been processed
        if (searchDistrict->parent->parent) {
            searchDistrict = searchDistrict->parent->isLeft() ?
                searchDistrict->parent->parent->rightChild :
                searchDistrict->parent->parent->leftChild;
        } else {
            searchDistrict = searchDistrict->parent;
        }
        ++d;
    }

    return currentNearest;
}

int main(int argc, const char * argv[]) {
    // 2-D training points (the kd-tree example from Chapter 3 of the book)
    vector<vector<double>> trainDataSet{{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};
    KdTree<double> *kdTree = new KdTree<double>;
    buildKdTree(kdTree, trainDataSet, 0);
    printKdTree(kdTree, 0);

    // Query point whose nearest neighbor we want to find
    vector<double> goal{3, 4.5};
    vector<double> nearestNeighbor = searchNearestNeighbor(kdTree, goal);

    for (auto i : nearestNeighbor) {
        cout << i << " ";
    }
    cout << endl;

    return 0;
}
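
For this training set the construction above splits on x first (median 7), giving (7,2) at the root, (5,4) and (9,6) one level down, and (2,3), (4,7), (8,1) as leaves, which matches the kd-tree example in the book; the nearest neighbor of the query (3, 4.5) is (2, 3). As a quick sanity check (a standalone sketch, not part of the original post), the brute-force linear scan below computes the same answer over the same data; it only assumes the training set and query used in main() and can be compiled separately with any C++11 compiler.

// Sanity check (standalone sketch): a brute-force linear scan over the same
// training set should agree with the kd-tree search and print "2 3" for the
// query (3, 4.5).
#include <iostream>
#include <vector>
#include <cmath>
#include <limits>
using namespace std;

int main() {
    vector<vector<double>> data{{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};
    vector<double> goal{3, 4.5};

    vector<double> best;
    double bestDist = numeric_limits<double>::max();
    for (const auto &p : data) {
        // Squared Euclidean distance is enough for comparing candidates
        double dist = pow(p[0] - goal[0], 2) + pow(p[1] - goal[1], 2);
        if (dist < bestDist) {
            bestDist = dist;
            best = p;
        }
    }

    cout << best[0] << " " << best[1] << endl;  // prints: 2 3
    return 0;
}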

 

posted @ 2015-10-25 15:46  skycore