ML-KDTree思想、划分、实现
1.概念
kd树是一种对k维空间中的实例进行存储以便快速检索的二叉树形结构。构造kd树相当于不断用垂直于坐标轴的超平面对k维空间切分,构成一系列k维超矩形区域。每个节点对应于k维超矩形区域。
所有非叶子节点可以视作用一个超平面把空间分区成两个半空间。节点左边的子树代表在超平面左边的点,节点右边的子树代表在超平面右边的点。
如果选择按照x轴划分,所有x值小于指定值的节点都会出现在左子树,所有x值大于指定值的节点都会出现在右子树。
假设二维,数据集T={(7,2),(5,4),(2,3),(4,7),(9,6),(8,1)}
2.轴的划分
1)轮流对轴进行划分,如二维,轮流对x,y划分
2)基于轴上方差最大的轴划分,这样划分区分度更大,如计算x上(7、5、2、4、9、8),y轴上(2、4、3、7、6、1)值的方差,取最大的作为划分轴。
3.生成
选定轴后,取轴的中点数字为划分点,如选定x轴:(7、5、2、4、9、8),然后中点取7,则用(7,2)点作为划分,左子树数据上x轴小于7,右子树x值大于=7,的2个数据集划分
如图,1次轴划分后
最终不断基于轴划分,然后即可产生KD树:
4.KD树的查找
KDTree通常用在KNN算法等地方,寻找某个数据点最近邻的k个点。通过构造KDTree,可以快速的查找数据点的k个最近点。
python创建:
class KDNode(object):
def __init__(self, value, split, left, right):
# value=[x,y]
self.value = value
self.split = split
self.right = right
self.left = left
class KDTree(object):
def __init__(self, data):
# data=[[x1,y1],[x2,y2]...,]
# 维度
k = len(data[0])
def CreateNode(split, data_set):
if not data_set:
return None
data_set.sort(key=lambda x: x[split])
# 整除2
split_pos = len(data_set) // 2
median = data_set[split_pos]
split_next = (split + 1) % k
return KDNode(median, split, CreateNode(split_next, data_set[: split_pos]),
CreateNode(split_next, data_set[split_pos + 1:]))
self.root = CreateNode(0, data)
查找:
def search(self, root, x, count=1):
nearest = []
for i in range(count):
nearest.append([-1, None])
self.nearest = np.array(nearest)
def recurve(node):
if node is not None:
axis = node.split
daxis = x[axis] - node.value[axis]
if daxis < 0:
recurve(node.left)
else:
recurve(node.right)
dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.value)))
for i, d in enumerate(self.nearest):
if d[0] < 0 or dist < d[0]: # 如果当前nearest内i处未标记(-1),或者新点与x距离更近
self.nearest = np.insert(self.nearest, i, [dist, node.value], axis=0) # 插入比i处距离更小的
self.nearest = self.nearest[:-1]
break
# 找到nearest集合里距离最大值的位置,为-1值的个数
n = list(self.nearest[:, 0]).count(-1)
# 切分轴的距离比nearest中最大的小(存在相交)
if self.nearest[-n - 1, 0] > abs(daxis):
if daxis < 0: # 相交,x[axis]< node.data[axis]时,去右边(左边已经遍历了)
recurve(node.right)
else: # x[axis]> node.data[axis]时,去左边,(右边已经遍历了)
recurve(node.left)
recurve(root)
return self.nearest
# 最近坐标点、最近距离和访问过的节点数
result = namedtuple("Result_tuple", "nearest_point nearest_dist nodes_visited")
data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
kd = KDTree(data)
#[3, 4.5]最近的3个点
n = kd.search(kd.root, [3, 4.5], 3)
print(n)
#[[1.8027756377319946 list([2, 3])]
[2.0615528128088303 list([5, 4])]
[2.692582403567252 list([4, 7])]]
5.基于sklearn
https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html
from sklearn.neighbors import KDTree
import numpy as np
from sklearn.neighbors import KDTree
np.random.seed(0)
X = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
tree = KDTree(X, leaf_size=2)
dist, ind = tree.query(X[:1], k=3)
print(dist) # 3个最近的距离
print(ind) # 3个最近的索引
print(X[ind]) # 3个最近的点
#
[[0. 3.16227766 4.47213595]]
[[0 1 3]]
[[[2 3]
[5 4]
[4 7]]]