k近邻法

文章记录的内容是参加DataWhale的组队学习统计学习方法(第二版)习题解答过程中的笔记与查缺补漏!
参考解答地址k近邻法

1. 参照图3.1,在二维空间中给出实例点,画出 \(k\) 为1和2时的 \(k\) 近邻法构成的空间划分,并对其进行比较,体会 \(k\) 值选择与模型复杂度及预测准确率的关系

解答思路

  • 参照图3.1,使用已给的实例点,采用sklearn的KNeighborsClassifier分类器,对 \(k=1\)\(k=2\) 时的模型进行训练
  • 使用matplotlib的contourf和scatter,画出k为1和2时的k近邻法构成的空间划分
  • 根据模型得到的预测结果,计算预测准确率,并设置图形标题
  • 根据程序生成的图,比较 \(k\) 为1和2时,\(k\) 值选择与模型复杂度、预测准确率的关系

具体的代码可以参考这里
其中一个重点是 \(k\) 的取值与模型复杂度的关系。原本我以为 \(k\) 越小,模型越简单,因为考虑到的邻居更少。但是后来我又看到了书中 p.52 页的一段话:

如果选择较小的 \(k\) 值,就相当于用较小的邻域中的训练实例进行预测,“学习”的近似误差(approximation error)会减小,只有与输入样本较近(相似的)的训练实例才会对预测起作用。但缺点是“学习”的估计误差(estimation error)会增大,预测结果对近邻的实例点会非常敏感。如果近邻的实例点恰巧是噪声,预测就会出错。即,\(k\) 值的减小就意味着整体模型变得复杂,容易发生过拟合
如果选择较大的 \(k\) 值,就相当于用较大的邻域中的训练实例进行预测。其优点是可以减少学习的估计误差。但缺点是学习的近似误差会增大。这时与输入实例较远的(不相似的)训练实例也会对预测起作用,是预测发生错误。\(k\) 值的增大意味着整体模型变得简单。

还有一个重点是 \(k\) 值选择与预测准确率的关系

\(k=1\) 时,模型易产生过拟合,但在过拟合发生前,\(k\) 值越大,预测准确率越低,也反映模型泛化能力越差,模型简单。反之,\(k\) 值越小,预测准确率越高,模型具有更好的泛化能力,模型复杂。

2. 利用例题3.2构造的 \(kd\) 树求点 \(x=(3,4.5)^T\) 的最近邻点

解答思路

  • 方法一:
    • 使用sklearn的KDTree类,结合例题3.2构建平衡kdkd树,配置相关参数(构建平衡树kd树算法,见书中第54页算法3.2内容);
    • 使用tree.query方法,查找(3, 4.5)的最近邻点(搜索kd树算法,见书中第55页第3.3.2节内容);
    • 根据第3步返回的参数,得到最近邻点。
  • 方法二:
      - 根据书中第56页算法3.3用kdkd树的最近邻搜索方法,查找(3, 4.5)的最近邻点

具体过程参考这里

其中一个重要的知识点就是 \(kd\) 树的搜索,根据书中第56页算法3.3(用\(kd\)树的最近邻搜索)

输入:已构造的kd树;目标点\(x\)
输出:\(x\) 的k近邻
(1)\(kd\) 树中找出包含目标点\(x\)的叶结点:从根结点出发,递归地向下访问树。若目标点 \(x\) 当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点,直到子结点为叶结点为止;
(2)如果“当前 \(k\) 近邻点集”元素数量小于\(k\)或者叶节点距离小于“当前 \(k\) 近邻点集”中最远点距离,那么将叶节点插入“当前k近邻点集”;
(3)递归地向上回退,在每个结点进行以下操作:
  (a)如果“当前 \(k\) 近邻点集”元素数量小于 \(k\) 或者当前节点距离小于“当前 \(k\) 近邻点集”中最远点距离,那么将该节点插入“当前 \(k\) 近邻点集”。
  (b)检查另一子结点对应的区域是否与以目标点为球心、以目标点与“当前 \(k\) 近邻点集”中最远点间的距离为半径的超球体相交。
  如果相交,可能在另一个子结点对应的区域内存在距目标点更近的点,移动到另一个子结点,接着,递归地进行近邻搜索;
  如果不相交,向上回退;
(4)当回退到根结点时,搜索结束,最后的“当前 \(k\) 近邻点集”即为 \(x\) 的近邻点。

3. 参照算法3.3,写出输出为 \(x\)\(k\) 近邻的算法

解答思路

  • 参考书中第56页算法3.3(用kdkd树的最近邻搜索),写出输出为xx的kk近邻算法;
  • 根据算法步骤,写出算法代码,并用习题3.2的解进行验证。

原文见这里
注意:

  • \(kd\) 树是对 \(k\) 维空间的一个划分(\(k\) 并不是 \(k\) 近邻的 \(k\)),相当于不断用垂直该空间中某个轴的超平面对空间进行划分
  • \(kd\) 树构建过程中,选择切分点的方式多种多样,例如对各维度进行统计,选择方差较大的进行分隔
  • 构建 \(kd\) 树后,数据集中的每个样本对应树中的一个结点并不是只在叶子结点上
  • 若数据集中的样本过多,可以随机从中选择一定数量的样本来选择指定维上的切分点
  • 若数据集中有多个点在当前选择的维上有相同的值,只需要选择其中一个放在刚生成的节点上

参考代码:

import json


class Node:
    """节点类"""

    def __init__(self, value, index, left_child, right_child):
        self.value = value.tolist()
        self.index = index
        self.left_child = left_child
        self.right_child = right_child

    def __repr__(self):
        return json.dumps(self, indent=3, default=lambda obj: obj.__dict__, ensure_ascii=False, allow_nan=False)

class KDTree:
    """kd tree类"""

    def __init__(self, data):
        # 数据集
        self.data = np.asarray(data)
        # kd树
        self.kd_tree = None
        # 创建平衡kd树
        self._create_kd_tree(data)

    def _split_sub_tree(self, data, depth=0):
        # 算法3.2第3步:直到子区域没有实例存在时停止
        if len(data) == 0:
            return None
        # 算法3.2第2步:选择切分坐标轴, 从0开始(书中是从1开始)
        l = depth % data.shape[1]
        # 对数据进行排序
        data = data[data[:, l].argsort()]
        # 算法3.2第1步:将所有实例坐标的中位数作为切分点
        median_index = data.shape[0] // 2
        # 获取结点在数据集中的位置,注意此处是用于切分的结点在完整数据集中的索引!
        node_index = [i for i, v in enumerate(
            self.data) if list(v) == list(data[median_index])]
        return Node(
            # 本结点
            value=data[median_index],
            # 本结点在数据集中的位置
            index=node_index[0],
            # 左子结点
            left_child=self._split_sub_tree(data[:median_index], depth + 1),
            # 右子结点
            right_child=self._split_sub_tree(
                data[median_index + 1:], depth + 1)
        )

    def _create_kd_tree(self, X):
        self.kd_tree = self._split_sub_tree(X)

    def query(self, data, k=1):
        data = np.asarray(data)
        hits = self._search(data, self.kd_tree, k=k, k_neighbor_sets=list())
        dd = np.array([hit[0] for hit in hits])
        ii = np.array([hit[1] for hit in hits])
        return dd, ii

    def __repr__(self):
        return str(self.kd_tree)

    @staticmethod
    def _cal_node_distance(node1, node2):
        """计算两个结点之间的距离"""
        return np.sqrt(np.sum(np.square(node1 - node2)))

    def _search(self, point, tree=None, k=1, k_neighbor_sets=None, depth=0):
        if k_neighbor_sets is None:
            k_neighbor_sets = []
        if tree is None:
            return k_neighbor_sets

        # (1)找到包含目标点x的叶结点
        if tree.left_child is None and tree.right_child is None:
            # 更新当前k近邻点集
            return self._update_k_neighbor_sets(k_neighbor_sets, k, tree, point)

        # 递归地向下访问kd树
        if point[0][depth % k] < tree.value[depth % k]:
            direct = 'left'
            next_branch = tree.left_child
        else:
            direct = 'right'
            next_branch = tree.right_child
        if next_branch is not None:
            # (3)(a) 判断当前结点,并更新当前k近邻点集
            k_neighbor_sets = self._update_k_neighbor_sets(
                k_neighbor_sets, k, next_branch, point)
            # (3)(b)检查另一子结点对应的区域是否相交
            if direct == 'left':
                node_distance = self._cal_node_distance(
                    point, tree.right_child.value)
                if k_neighbor_sets[0][0] > node_distance:
                    # 如果相交,递归地进行近邻搜索
                    return self._search(point, tree=tree.right_child, k=k, depth=depth + 1,
                                        k_neighbor_sets=k_neighbor_sets)
            else:
                node_distance = self._cal_node_distance(
                    point, tree.left_child.value)
                if k_neighbor_sets[0][0] > node_distance:
                    return self._search(point, tree=tree.left_child, k=k, depth=depth + 1,
                                        k_neighbor_sets=k_neighbor_sets)

        return self._search(point, tree=next_branch, k=k, depth=depth + 1, k_neighbor_sets=k_neighbor_sets)

    def _update_k_neighbor_sets(self, best, k, tree, point):
        # 计算目标点与当前结点的距离
        node_distance = self._cal_node_distance(point, tree.value)
        if len(best) == 0:
            best.append((node_distance, tree.index, tree.value))
        elif len(best) < k:
            # 如果“当前k近邻点集”元素数量小于k
            self._insert_k_neighbor_sets(best, tree, node_distance)
        else:
            # 叶节点距离小于“当前 𝑘 近邻点集”中最远点距离
            if best[0][0] > node_distance:
                best = best[1:]
                self._insert_k_neighbor_sets(best, tree, node_distance)
        return best

    @staticmethod
    def _insert_k_neighbor_sets(best, tree, node_distance):
        """将距离最远的结点排在前面"""
        n = len(best)
        for i, item in enumerate(best):
            if item[0] < node_distance:
                # 将距离最远的结点插入到前面
                best.insert(i, (node_distance, tree.index, tree.value))
                break
        if len(best) == n:
            best.append((node_distance, tree.index, tree.value))
# 打印信息
def print_k_neighbor_sets(k, ii, dd):
    if k == 1:
        text = "x点的最近邻点是"
    else:
        text = "x点的%d个近邻点是" % k

    for i, index in enumerate(ii):
        res = X_train[index]
        if i == 0:
            text += str(tuple(res))
        else:
            text += ", " + str(tuple(res))

    if k == 1:
        text += ",距离是"
    else:
        text += ",距离分别是"
    for i, dist in enumerate(dd):
        if i == 0:
            text += "%.4f" % dist
        else:
            text += ", %.4f" % dist

    print(text)
import numpy as np

X_train = np.array([[2, 3],
                    [5, 4],
                    [9, 6],
                    [4, 7],
                    [8, 1],
                    [7, 2]])
kd_tree = KDTree(X_train)
# 设置k值
k = 1
# 查找邻近的结点
dists, indices = kd_tree.query(np.array([[3, 4.5]]), k=k)
# 打印邻近结点
print_k_neighbor_sets(k, indices, dists)
x点的最近邻点是(2, 3),距离是1.8028
# 打印kd树
kd_tree
{
   "value": [
      7,
      2
   ],
   "index": 5,
   "left_child": {
      "value": [
         5,
         4
      ],
      "index": 1,
      "left_child": {
         "value": [
            2,
            3
         ],
         "index": 0,
         "left_child": null,
         "right_child": null
      },
      "right_child": {
         "value": [
            4,
            7
         ],
         "index": 3,
         "left_child": null,
         "right_child": null
      }
   },
   "right_child": {
      "value": [
         9,
         6
      ],
      "index": 2,
      "left_child": {
         "value": [
            8,
            1
         ],
         "index": 4,
         "left_child": null,
         "right_child": null
      },
      "right_child": null
   }
}
posted @ 2021-12-21 14:20  Milkha  阅读(311)  评论(0编辑  收藏  举报