ML-KDTree思想、划分、实现

1.概念

        kd树是一种对k维空间中的实例进行存储以便快速检索的二叉树形结构。构造kd树相当于不断用垂直于坐标轴的超平面对k维空间切分,构成一系列k维超矩形区域。每个节点对应于k维超矩形区域。

所有非叶子节点可以视作用一个超平面把空间分区成两个半空间。节点左边的子树代表在超平面左边的点,节点右边的子树代表在超平面右边的点。

如果选择按照x轴划分,所有x值小于指定值的节点都会出现在左子树,所有x值大于指定值的节点都会出现在右子树。

 

假设二维,数据集T={(7,2),(5,4),(2,3),(4,7),(9,6),(8,1)}

2.轴的划分

1)轮流对轴进行划分,如二维,轮流对x,y划分

2)基于轴上方差最大的轴划分,这样划分区分度更大,如计算x上(7、5、2、4、9、8),y轴上(2、4、3、7、6、1)值的方差,取最大的作为划分轴。

3.生成

选定轴后,取轴的中点数字为划分点,如选定x轴:(7、5、2、4、9、8),然后中点取7,则用(7,2)点作为划分,左子树数据上x轴小于7,右子树x值大于=7,的2个数据集划分

如图,1次轴划分后

最终不断基于轴划分,然后即可产生KD树:

4.KD树的查找

KDTree通常用在KNN算法等地方,寻找某个数据点最近邻的k个点。通过构造KDTree,可以快速的查找数据点的k个最近点。

python创建: 



class KDNode(object):
    def __init__(self, value, split, left, right):
        # value=[x,y]
        self.value = value
        self.split = split
        self.right = right
        self.left = left


class KDTree(object):
    def __init__(self, data):
        # data=[[x1,y1],[x2,y2]...,]
        # 维度
        k = len(data[0])

        def CreateNode(split, data_set):
            if not data_set:
                return None
            data_set.sort(key=lambda x: x[split])
            # 整除2
            split_pos = len(data_set) // 2
            median = data_set[split_pos]
            split_next = (split + 1) % k

            return KDNode(median, split, CreateNode(split_next, data_set[: split_pos]),
                          CreateNode(split_next, data_set[split_pos + 1:]))

        self.root = CreateNode(0, data)

查找:

    def search(self, root, x, count=1):
        nearest = []
        for i in range(count):
            nearest.append([-1, None])
        self.nearest = np.array(nearest)

        def recurve(node):
            if node is not None:
                axis = node.split
                daxis = x[axis] - node.value[axis]
                if daxis < 0:
                    recurve(node.left)
                else:
                    recurve(node.right)
                dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.value)))
                for i, d in enumerate(self.nearest):
                    if d[0] < 0 or dist < d[0]:  # 如果当前nearest内i处未标记(-1),或者新点与x距离更近
                        self.nearest = np.insert(self.nearest, i, [dist, node.value], axis=0)  # 插入比i处距离更小的
                        self.nearest = self.nearest[:-1]
                        break
                # 找到nearest集合里距离最大值的位置,为-1值的个数
                n = list(self.nearest[:, 0]).count(-1)
                # 切分轴的距离比nearest中最大的小(存在相交)
                if self.nearest[-n - 1, 0] > abs(daxis):
                    if daxis < 0:  # 相交,x[axis]< node.data[axis]时,去右边(左边已经遍历了)
                        recurve(node.right)
                    else:  # x[axis]> node.data[axis]时,去左边,(右边已经遍历了)
                        recurve(node.left)
        recurve(root)
        return self.nearest


# 最近坐标点、最近距离和访问过的节点数
result = namedtuple("Result_tuple", "nearest_point nearest_dist nodes_visited")

data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
kd = KDTree(data)

#[3, 4.5]最近的3个点
n = kd.search(kd.root, [3, 4.5], 3)
print(n)

#[[1.8027756377319946 list([2, 3])]
 [2.0615528128088303 list([5, 4])]
 [2.692582403567252 list([4, 7])]]

5.基于sklearn

https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html

from sklearn.neighbors import KDTree
import numpy as np
from sklearn.neighbors import KDTree

np.random.seed(0)
X = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])

tree = KDTree(X, leaf_size=2)
dist, ind = tree.query(X[:1], k=3)

print(dist)  # 3个最近的距离
print(ind)  # 3个最近的索引
print(X[ind])  # 3个最近的点

#
[[0.         3.16227766 4.47213595]]
[[0 1 3]]
[[[2 3]
  [5 4]
  [4 7]]]

 

 

posted @ 2019-02-13 23:58  jj千寻  阅读(251)  评论(0编辑  收藏  举报