k近邻法
文章记录的内容是参加DataWhale的组队学习统计学习方法(第二版)习题解答过程中的笔记与查缺补漏!
参考解答地址:k近邻法。
1. 参照图3.1,在二维空间中给出实例点,画出 \(k\) 为1和2时的 \(k\) 近邻法构成的空间划分,并对其进行比较,体会 \(k\) 值选择与模型复杂度及预测准确率的关系
解答思路:
- 参照图3.1,使用已给的实例点,采用sklearn的KNeighborsClassifier分类器,对 \(k=1\) 和 \(k=2\) 时的模型进行训练
- 使用matplotlib的contourf和scatter,画出k为1和2时的k近邻法构成的空间划分
- 根据模型得到的预测结果,计算预测准确率,并设置图形标题
- 根据程序生成的图,比较 \(k\) 为1和2时,\(k\) 值选择与模型复杂度、预测准确率的关系
具体的代码可以参考这里。
其中一个重点是 \(k\) 的取值与模型复杂度的关系。原本我以为 \(k\) 越小,模型越简单,因为考虑到的邻居更少。但是后来我又看到了书中 p.52 页的一段话:
如果选择较小的 \(k\) 值,就相当于用较小的邻域中的训练实例进行预测,“学习”的近似误差(approximation error)会减小,只有与输入样本较近(相似的)的训练实例才会对预测起作用。但缺点是“学习”的估计误差(estimation error)会增大,预测结果对近邻的实例点会非常敏感。如果近邻的实例点恰巧是噪声,预测就会出错。即,\(k\) 值的减小就意味着整体模型变得复杂,容易发生过拟合。
如果选择较大的 \(k\) 值,就相当于用较大的邻域中的训练实例进行预测。其优点是可以减少学习的估计误差。但缺点是学习的近似误差会增大。这时与输入实例较远的(不相似的)训练实例也会对预测起作用,是预测发生错误。\(k\) 值的增大意味着整体模型变得简单。
还有一个重点是 \(k\) 值选择与预测准确率的关系:
当 \(k=1\) 时,模型易产生过拟合,但在过拟合发生前,\(k\) 值越大,预测准确率越低,也反映模型泛化能力越差,模型简单。反之,\(k\) 值越小,预测准确率越高,模型具有更好的泛化能力,模型复杂。
2. 利用例题3.2构造的 \(kd\) 树求点 \(x=(3,4.5)^T\) 的最近邻点
解答思路:
- 方法一:
- 使用sklearn的KDTree类,结合例题3.2构建平衡kdkd树,配置相关参数(构建平衡树kd树算法,见书中第54页算法3.2内容);
- 使用tree.query方法,查找(3, 4.5)的最近邻点(搜索kd树算法,见书中第55页第3.3.2节内容);
- 根据第3步返回的参数,得到最近邻点。
- 方法二:
- 根据书中第56页算法3.3用kdkd树的最近邻搜索方法,查找(3, 4.5)的最近邻点
具体过程参考这里。
其中一个重要的知识点就是 \(kd\) 树的搜索,根据书中第56页算法3.3(用\(kd\)树的最近邻搜索)
输入:已构造的kd树;目标点\(x\);
输出:\(x\) 的k近邻
(1)在 \(kd\) 树中找出包含目标点\(x\)的叶结点:从根结点出发,递归地向下访问树。若目标点 \(x\) 当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点,直到子结点为叶结点为止;
(2)如果“当前 \(k\) 近邻点集”元素数量小于\(k\)或者叶节点距离小于“当前 \(k\) 近邻点集”中最远点距离,那么将叶节点插入“当前k近邻点集”;
(3)递归地向上回退,在每个结点进行以下操作:
(a)如果“当前 \(k\) 近邻点集”元素数量小于 \(k\) 或者当前节点距离小于“当前 \(k\) 近邻点集”中最远点距离,那么将该节点插入“当前 \(k\) 近邻点集”。
(b)检查另一子结点对应的区域是否与以目标点为球心、以目标点与“当前 \(k\) 近邻点集”中最远点间的距离为半径的超球体相交。
如果相交,可能在另一个子结点对应的区域内存在距目标点更近的点,移动到另一个子结点,接着,递归地进行近邻搜索;
如果不相交,向上回退;
(4)当回退到根结点时,搜索结束,最后的“当前 \(k\) 近邻点集”即为 \(x\) 的近邻点。
3. 参照算法3.3,写出输出为 \(x\) 的 \(k\) 近邻的算法
解答思路:
- 参考书中第56页算法3.3(用kdkd树的最近邻搜索),写出输出为xx的kk近邻算法;
- 根据算法步骤,写出算法代码,并用习题3.2的解进行验证。
原文见这里。
注意:
- \(kd\) 树是对 \(k\) 维空间的一个划分(\(k\) 并不是 \(k\) 近邻的 \(k\)),相当于不断用垂直该空间中某个轴的超平面对空间进行划分
- \(kd\) 树构建过程中,选择切分点的方式多种多样,例如对各维度进行统计,选择方差较大的进行分隔
- 构建 \(kd\) 树后,数据集中的每个样本对应树中的一个结点,并不是只在叶子结点上
- 若数据集中的样本过多,可以随机从中选择一定数量的样本来选择指定维上的切分点
- 若数据集中有多个点在当前选择的维上有相同的值,只需要选择其中一个放在刚生成的节点上
参考代码:
import json
class Node:
"""节点类"""
def __init__(self, value, index, left_child, right_child):
self.value = value.tolist()
self.index = index
self.left_child = left_child
self.right_child = right_child
def __repr__(self):
return json.dumps(self, indent=3, default=lambda obj: obj.__dict__, ensure_ascii=False, allow_nan=False)
class KDTree:
"""kd tree类"""
def __init__(self, data):
# 数据集
self.data = np.asarray(data)
# kd树
self.kd_tree = None
# 创建平衡kd树
self._create_kd_tree(data)
def _split_sub_tree(self, data, depth=0):
# 算法3.2第3步:直到子区域没有实例存在时停止
if len(data) == 0:
return None
# 算法3.2第2步:选择切分坐标轴, 从0开始(书中是从1开始)
l = depth % data.shape[1]
# 对数据进行排序
data = data[data[:, l].argsort()]
# 算法3.2第1步:将所有实例坐标的中位数作为切分点
median_index = data.shape[0] // 2
# 获取结点在数据集中的位置,注意此处是用于切分的结点在完整数据集中的索引!
node_index = [i for i, v in enumerate(
self.data) if list(v) == list(data[median_index])]
return Node(
# 本结点
value=data[median_index],
# 本结点在数据集中的位置
index=node_index[0],
# 左子结点
left_child=self._split_sub_tree(data[:median_index], depth + 1),
# 右子结点
right_child=self._split_sub_tree(
data[median_index + 1:], depth + 1)
)
def _create_kd_tree(self, X):
self.kd_tree = self._split_sub_tree(X)
def query(self, data, k=1):
data = np.asarray(data)
hits = self._search(data, self.kd_tree, k=k, k_neighbor_sets=list())
dd = np.array([hit[0] for hit in hits])
ii = np.array([hit[1] for hit in hits])
return dd, ii
def __repr__(self):
return str(self.kd_tree)
@staticmethod
def _cal_node_distance(node1, node2):
"""计算两个结点之间的距离"""
return np.sqrt(np.sum(np.square(node1 - node2)))
def _search(self, point, tree=None, k=1, k_neighbor_sets=None, depth=0):
if k_neighbor_sets is None:
k_neighbor_sets = []
if tree is None:
return k_neighbor_sets
# (1)找到包含目标点x的叶结点
if tree.left_child is None and tree.right_child is None:
# 更新当前k近邻点集
return self._update_k_neighbor_sets(k_neighbor_sets, k, tree, point)
# 递归地向下访问kd树
if point[0][depth % k] < tree.value[depth % k]:
direct = 'left'
next_branch = tree.left_child
else:
direct = 'right'
next_branch = tree.right_child
if next_branch is not None:
# (3)(a) 判断当前结点,并更新当前k近邻点集
k_neighbor_sets = self._update_k_neighbor_sets(
k_neighbor_sets, k, next_branch, point)
# (3)(b)检查另一子结点对应的区域是否相交
if direct == 'left':
node_distance = self._cal_node_distance(
point, tree.right_child.value)
if k_neighbor_sets[0][0] > node_distance:
# 如果相交,递归地进行近邻搜索
return self._search(point, tree=tree.right_child, k=k, depth=depth + 1,
k_neighbor_sets=k_neighbor_sets)
else:
node_distance = self._cal_node_distance(
point, tree.left_child.value)
if k_neighbor_sets[0][0] > node_distance:
return self._search(point, tree=tree.left_child, k=k, depth=depth + 1,
k_neighbor_sets=k_neighbor_sets)
return self._search(point, tree=next_branch, k=k, depth=depth + 1, k_neighbor_sets=k_neighbor_sets)
def _update_k_neighbor_sets(self, best, k, tree, point):
# 计算目标点与当前结点的距离
node_distance = self._cal_node_distance(point, tree.value)
if len(best) == 0:
best.append((node_distance, tree.index, tree.value))
elif len(best) < k:
# 如果“当前k近邻点集”元素数量小于k
self._insert_k_neighbor_sets(best, tree, node_distance)
else:
# 叶节点距离小于“当前 𝑘 近邻点集”中最远点距离
if best[0][0] > node_distance:
best = best[1:]
self._insert_k_neighbor_sets(best, tree, node_distance)
return best
@staticmethod
def _insert_k_neighbor_sets(best, tree, node_distance):
"""将距离最远的结点排在前面"""
n = len(best)
for i, item in enumerate(best):
if item[0] < node_distance:
# 将距离最远的结点插入到前面
best.insert(i, (node_distance, tree.index, tree.value))
break
if len(best) == n:
best.append((node_distance, tree.index, tree.value))
# 打印信息
def print_k_neighbor_sets(k, ii, dd):
if k == 1:
text = "x点的最近邻点是"
else:
text = "x点的%d个近邻点是" % k
for i, index in enumerate(ii):
res = X_train[index]
if i == 0:
text += str(tuple(res))
else:
text += ", " + str(tuple(res))
if k == 1:
text += ",距离是"
else:
text += ",距离分别是"
for i, dist in enumerate(dd):
if i == 0:
text += "%.4f" % dist
else:
text += ", %.4f" % dist
print(text)
import numpy as np
X_train = np.array([[2, 3],
[5, 4],
[9, 6],
[4, 7],
[8, 1],
[7, 2]])
kd_tree = KDTree(X_train)
# 设置k值
k = 1
# 查找邻近的结点
dists, indices = kd_tree.query(np.array([[3, 4.5]]), k=k)
# 打印邻近结点
print_k_neighbor_sets(k, indices, dists)
x点的最近邻点是(2, 3),距离是1.8028
# 打印kd树
kd_tree
{
"value": [
7,
2
],
"index": 5,
"left_child": {
"value": [
5,
4
],
"index": 1,
"left_child": {
"value": [
2,
3
],
"index": 0,
"left_child": null,
"right_child": null
},
"right_child": {
"value": [
4,
7
],
"index": 3,
"left_child": null,
"right_child": null
}
},
"right_child": {
"value": [
9,
6
],
"index": 2,
"left_child": {
"value": [
8,
1
],
"index": 4,
"left_child": null,
"right_child": null
},
"right_child": null
}
}