kd-tree的实现
参考百度百科:http://baike.baidu.com/link?url=JLBeRUhL6WLyp8R6TAFDD8swLfazjQnOaSXBY3AydkrVQG8XpCJ8EIh4bWpB02wQxxzPrK723ulRCzSKxkFLy_
下面是我的实现
// kd-tree.cpp : Defines the entry point for the console application.
//
// A small k-d tree over caller-owned points with exact nearest-neighbour
// lookup (Euclidean metric).

// The original project pulled _tmain/_TCHAR in through the MSVC precompiled
// header. NOTE(review): if the project is built with /Yu"stdafx.h", MSVC may
// require the stdafx.h include to sit outside any #if block - move it above
// the guard in that case.
#ifdef _MSC_VER
#include "stdafx.h"
#include <tchar.h>
#else
// Stand-in for the MSVC generic-text character type. On these toolchains
// _tmain below is an ordinary function; call it from your own main().
typedef char _TCHAR;
#endif

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <vector>

// Scalar type of a single coordinate.
typedef double KeyType;

// k-d tree that indexes points stored as arrays of `dim` KeyType values.
//
// The tree never copies or frees point data: every node keeps the pointer
// that was handed to create(), so the points must outlive the tree (or at
// least every later call to nearest()).
class kdtree
{
public:
    struct kdnode
    {
        kdnode* lnode;    // subtree whose points are <= value on splitdim
        kdnode* rnode;    // subtree whose points are >= value on splitdim
        kdnode* parent;   // nullptr for the root
        KeyType* value;   // borrowed pointer to the point stored in this node
        int splitdim;     // dimension this node splits on

        kdnode()
            : lnode(nullptr), rnode(nullptr), parent(nullptr),
              value(nullptr), splitdim(0)
        {
        }
    };

    // dimen: number of coordinates per point; must be greater than 1.
    kdtree(int dimen = 2) : dim(dimen), root(nullptr)
    {
        assert(dimen > 1);
    }

    ~kdtree() { destroy(root); }

    // The tree owns its nodes through raw pointers, so a member-wise copy
    // would free every node twice.
    kdtree(const kdtree&) = delete;
    kdtree& operator=(const kdtree&) = delete;

    // Returns the stored point closest to `val` (an array of `dim`
    // coordinates), or nullptr when the tree is empty or `val` is null.
    // When several points are equally close, any one of them is returned.
    KeyType* nearest(KeyType* const val);

    // Rebuilds the tree from `datanums` points, each an array of `dim`
    // coordinates. Any previously built tree is released first. Passing no
    // points leaves the tree empty. Echoes the input points to std::cout.
    void create(KeyType**& indata, int datanums);

    // Root node, or nullptr while the tree is empty.
    kdnode* get_root() { return root; }

private:
    int dim;        // number of coordinates per point
    kdnode* root;   // owning pointer to the root node

    // Picks the dimension with the largest spread (sum of squared deviations
    // from the mean); splitting there tends to keep the tree shallow.
    int getsplitdim(const std::vector<KeyType*>& input) const;

    // Fills `node` from the non-empty point set `points` and recursively
    // builds its children. `points` is reordered in the process.
    void build(kdnode* node, std::vector<KeyType*>& points);

    // Recursive nearest-neighbour search below `node`; `best`/`bestSq` carry
    // the closest point found so far and its squared distance to `query`.
    void search(const kdnode* node, const KeyType* query,
                KeyType*& best, double& bestSq) const;

    // Squared Euclidean distance between two `dim`-dimensional points.
    double sqdist(const KeyType* a, const KeyType* b) const;

    // Frees every node reachable from `node` (nullptr is allowed).
    static void destroy(kdnode* node);
};

void kdtree::destroy(kdnode* node)
{
    // Explicit stack instead of recursion, as in the original destructor.
    std::vector<kdnode*> pending;
    if (node != nullptr)
        pending.push_back(node);
    while (!pending.empty())
    {
        kdnode* current = pending.back();
        pending.pop_back();
        if (current->lnode != nullptr)
            pending.push_back(current->lnode);
        if (current->rnode != nullptr)
            pending.push_back(current->rnode);
        delete current;
    }
}

double kdtree::sqdist(const KeyType* a, const KeyType* b) const
{
    double sum = 0;
    for (int i = 0; i < dim; ++i)
    {
        const double diff = a[i] - b[i];
        sum += diff * diff;
    }
    return sum;
}

int kdtree::getsplitdim(const std::vector<KeyType*>& input) const
{
    const std::size_t count = input.size();
    int bestDim = 0;          // well-defined fallback even if no spread wins
    double bestSpread = -1;
    for (int d = 0; d < dim; ++d)
    {
        double mean = 0;
        for (std::size_t j = 0; j < count; ++j)
            mean += input[j][d];
        mean /= double(count);

        double spread = 0;
        for (std::size_t j = 0; j < count; ++j)
        {
            const double diff = input[j][d] - mean;
            spread += diff * diff;
        }
        if (spread > bestSpread)
        {
            bestDim = d;
            bestSpread = spread;
        }
    }
    return bestDim;
}

void kdtree::build(kdnode* node, std::vector<KeyType*>& points)
{
    const int axis = getsplitdim(points);
    // Strict '<' so that equal keys compare false (strict weak ordering).
    std::sort(points.begin(), points.end(),
              [axis](const KeyType* lhs, const KeyType* rhs) {
                  return lhs[axis] < rhs[axis];
              });

    // The median becomes this node; everything before it goes left and
    // everything after it goes right.
    const std::size_t mid = points.size() / 2;
    node->value = points[mid];
    node->splitdim = axis;

    std::vector<KeyType*> left(points.begin(), points.begin() + mid);
    std::vector<KeyType*> right(points.begin() + mid + 1, points.end());

    if (!left.empty())
    {
        kdnode* child = new kdnode;
        child->parent = node;
        node->lnode = child;   // link first so the destructor can reach it
        build(child, left);
    }
    if (!right.empty())
    {
        kdnode* child = new kdnode;
        child->parent = node;
        node->rnode = child;
        build(child, right);
    }
}

void kdtree::create(KeyType**& indata, int datanums)
{
    // Drop any tree built by an earlier call instead of leaking it.
    destroy(root);
    root = nullptr;
    if (indata == nullptr || datanums <= 0)
        return;

    std::vector<KeyType*> points;
    points.reserve(static_cast<std::size_t>(datanums));
    for (int i = 0; i < datanums; ++i)
    {
        for (int j = 0; j < dim; ++j)
            std::cout << indata[i][j] << " ";
        std::cout << std::endl;
        points.push_back(indata[i]);
    }

    root = new kdnode;
    build(root, points);
}

void kdtree::search(const kdnode* node, const KeyType* query,
                    KeyType*& best, double& bestSq) const
{
    if (node == nullptr)
        return;

    const double here = sqdist(query, node->value);
    // Always accept the first candidate so a non-empty tree yields a result.
    if (best == nullptr || here < bestSq)
    {
        bestSq = here;
        best = node->value;
    }

    const int axis = node->splitdim;
    const double gap = query[axis] - node->value[axis];
    const kdnode* nearSide = gap > 0 ? node->rnode : node->lnode;
    const kdnode* farSide = gap > 0 ? node->lnode : node->rnode;

    search(nearSide, query, best, bestSq);
    // Every point on the far side is at least |gap| away along `axis`, so
    // that subtree can only matter if the splitting plane is closer than the
    // best match found so far.
    if (gap * gap < bestSq)
        search(farSide, query, best, bestSq);
}

KeyType* kdtree::nearest(KeyType* const val)
{
    if (root == nullptr || val == nullptr)
        return nullptr;

    KeyType* best = nullptr;
    double bestSq = std::numeric_limits<double>::infinity();
    search(root, val, best, bestSq);
    return best;
}

int _tmain(int /*argc*/, _TCHAR* /*argv*/[])
{
    kdtree kd(2);
    // Alternative data set: { 12, 45, 34, 12, 17, 34, 43, 889, 86, 54 }
    KeyType bb[6][2] = { { 2, 3 }, { 5, 4 }, { 9, 6 }, { 4, 7 }, { 8, 1 }, { 7, 2 } };

    for (int i = 0; i < 6; i++)
    {
        for (int j = 0; j < 2; j++)
            std::cout << bb[i][j] << " ";
        std::cout << std::endl;
    }

    // Row-pointer table on the stack: nothing to free afterwards.
    KeyType* rows[6];
    for (int i = 0; i < 6; i++)
        rows[i] = bb[i];
    KeyType** in = rows;

    kd.create(in, 6);
    kdtree::kdnode* root = kd.get_root();

    KeyType hh[2] = { 2, 4.5 };
    KeyType* n = kd.nearest(hh);

    // Kept for inspection in a debugger, as in the original program.
    (void)root;
    (void)n;

    std::system("pause");
    return 0;
}
版权声明: