Python 机器学习 K-近邻算法 KD树
1、构建KD树
在k维空间中的数据点集合中构建kd树,通常通过递归方式进行。在每一层递归中,选择一个维度并在该维度上找到中位数点,以此点为界将数据集分割成两部分,递归地在每部分上重复这一过程。k-d树每个节点都是k维数值点,而且每个非叶节点可以被认为是用一个轴-垂直的超平面将空间分割成两半。KD树的构建和查询效率与数据分布有关。对于某些特殊的数据分布,KD树可能不比简单的线性搜索快。维数据(维度的数量很大)可能会导致KD树效率降低,原因是由于“维度的诅咒”。在这些情况下,可能需要考虑其他的数据结构或算法。
从k个维度中选择一个轴。通常,KD树的构建从第一个维度开始,然后沿着树的每个分支,依次选择下一个维度。在每个更深的树级别上,选择下一个维度。这通常通过模运算实现,有些实现选择方差最大的维度作为切分轴。这种方法基于这样的假设:方差大的维度可能更有利于将数据分得更均匀。某些高级的实现会分析数据在不同维度上的分布,选择最能够提高树效率的维度进行切分。在选择的轴上找到中位数作为分割点,把数据分割成两部分。可以有效地将搜索空间减少到与查询点更接近的区域。对每个子集重复上述过程,选择轴并找到分割点,直到满足终止条件,如子集大小低于预设阈值或达到树的最大深度。每个分割点成为树的一个节点,具有两个子节点:一个代表左子集,另一个代表右子集。这样递归构造下去,直到叶节点。
import json
class Node:
def __init__(self, point, left=None, right=None):
self.point = point
self.left = left
self.right = right
def build_kdtree(points, depth=0):
n = len(points)
if n == 0:
return None
k = len(points[0]) # 假设所有点在K维空间
axis = depth % k # 选择轴
points.sort(key=lambda x: x[axis])
median_index = n // 2
return Node(
point=points[median_index],
left=build_kdtree(points[:median_index], depth + 1),
right=build_kdtree(points[median_index + 1:], depth + 1)
)
def kd_tree_to_dict(node):
if node is None:
return None
return {
"point": node.point,
"left": kd_tree_to_dict(node.left),
"right": kd_tree_to_dict(node.right)
}
# 构建KD树
points = [(3, 6), (17, 15), (13, 15), (6, 12), (9, 1), (2, 7), (10, 19)]
kdtree = build_kdtree(points)
# 将KD树转换为字典
kdtree_dict = kd_tree_to_dict(kdtree)
# 将字典转换为JSON字符串
kdtree_json = json.dumps(kdtree_dict, indent=4)
print(kdtree_json)
2、搜索最近邻
K-近邻算法(KNN)中,KD树可以用来高效地搜索最近邻。从根节点开始,向下遍历树直到叶子节点,同时记录下路径上的节点。在回溯过程中,检查其他子树是否有可能包含更近的点。如果可能,搜索那个子树。选择距离查询点最近的点作为最近邻。
from sklearn.neighbors import KDTree
import numpy as np
# 示例数据
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
# 构建KD树
kdtree = KDTree(X, leaf_size=2)
# 查询点
query_point = np.array([[2, 3]])
# 执行最近邻搜索
dist, ind = kdtree.query(query_point, k=1)
print("Nearest point:", X[ind])
print("Distance:", dist)
3、 kd树优势
相比于逐一计算所有数据点的距离,kd树通过减少需要计算的距离数来提高效率,特别是在数据维度不是非常高的情况下。KD树是一种非常有效的结构,特别是在处理具有少量维度(如几十个维度或更少)的数据集时。然而,它的性能可能会随着维度的增加而下降,这种现象称为“维度的诅咒”。在高维数据集上,其他方法(如基于图的方法或近似算法)可能更有效。
4、scikit-learn 中的使用
在scikit-learn库中,KD树通常在后台自动构建,以加速KNN查询。如使用KNeighborsClassifier或KNeighborsRegressor时,可以通过设置algorithm='kd_tree'来指定使用KD树。
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix
# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 初始化KNN模型,使用KD树
knn = KNeighborsClassifier(n_neighbors=3, algorithm='kd_tree')
# 训练模型
knn.fit(X_train, y_train)
# 进行预测
y_pred = knn.predict(X_test)
# 打印结果
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))
详细文档:Python 机器学习 K-近邻算法