学习笔记-《统计学习方法》-第三章-k近邻

3 k近邻法

k近邻法的输入为实例的特征向量,对应于特征空间的点;输出为实例的类别,可以取多类。k近邻法假定给定一个训练数据集,其中的实例类别一定。分类时,对新的实例,根据其k个最近邻的训练实例的类别,通过多数表决等方式进行预测,因此不具备显式的学习过程,实际上利用训练数据集对特征向量空间进行划分,并作为其分类的“模型”。

3.1 k近邻算法

输入:训练数据集

T=(x1,y1),(x2,y2),...,(xN,yN)

其中,xiXRn为实例的特征向量,yiY={c1,c2,...,cK}为实例的类别,i=1,2,..N

输出:实例x所属的类y

(1)根据给定的距离度量,在训练集T中找出与x最邻近的k个点,涵盖这k个点的邻域记做Nk(x);

(2)在Nk(x)中根据分类决策规则(如多数表决)决定x的类别y

(3.1)y=argmaxcjxiNk(x)I(yi=cj),i=1,2,...,N;j=1,2,...,K

k=1时,是特殊情形,称为最近邻算法。

3.2 k近邻模型

3.2.1 模型

特征空间中,对每个训练实例点xi,距离该点比其它点更近的所有点组成一个区域,叫做单元。每个训练实例点拥有一个单元,所有训练实例点的单元构成对特征空间的一个划分。最近邻法将实例xi的类yi作为其单元中所有点的类标记。

涉及几个概念

3.2.2 距离度量

Lp距离(Minkowski距离)

Lp(xi,xj)=(l=1n|xi(l)xj(l)|p)1p

p=2时,即是欧氏距离。

p=1时,是曼哈顿距离(Manhattan distance)

L1(xi,xj)=(l=1n|xi(l)xj(l)|)

p=时,是各个坐标距离的最大值,又叫切比雪夫距离(Chebyshev distance)

L(xi,xj)=maxl|xi(l)xj(l)|

3.2.3 k值选择

k值减小,意味着整体模型变复杂,更容易过拟合

k值增大,意味着使用较大的邻域进行预测,减少估计误差,但会增加近似误差

一般使用交叉验证法确认

3.2.4 分类决策规则

一般使用多数表决(majority voting rule),等价于经验风险最小。

3.3 kd树

使用k近邻法,主要考虑的问题是如何快速进行近邻搜索。最简单的实现方法就是线性扫描,计算输入实例与每一个训练实例的距离,训练集较大的时,耗时巨大,基本不可行。一般采用kd树

3.3.1构造kd树

kd树是一种二叉树,是将k维空间中的实例进行存储,从而方便快速检索的一种方式。

方法如下:

  1. 构造根节点,使根节点对应于包含所有实例点的超矩形区域
  2. 递归进行如下操作
    1. 不断对k维空间划分,生成子结点,在超矩形区域(结点)上选择一个坐标轴,以及在这个坐标轴上的一个切分点,确定一个超平面,超平面通过选定的切分点,并垂直于选定的坐标轴,将当前的超矩形区域划分为左右两个子区域。
    2. 以上过程持续进行,直到没有实例时(即,所有的实例都被划分为叶结点)结束。

算法3.2

输入:k维空间数据集T={x1,x2,...xN},其中xi=(xi(1),xi(2),...,xi(k))Ti=1,2,...N

输出:kd树

  1. 构建根节点,根节点对应于包含Tk维空间的超矩形区域。

    x(1)为坐标轴,将所有实例的x(1)坐标的中位数作为切分点,将根节点对应的超矩形区域切分为两个子区域,左子区域对应在x(1)上小于切分点的子区域,右子区域对应大于切分点的子区域。

    将落在切分超平面的实例点保存在根结点(所以不一定用的中位数,其实还是要实例点)

  2. 重复:对深度为j的结点,选择x(l)作为切分的坐标轴,l=(j mod k)+1,这只是一个确定切分坐标轴的方式,保证一定有一个k维中的维度被选中。

    对区域中所有实例在该维度上的值求中位数,作为新的切分点,将区域再次划分为两个子区域

    同样,将落在切分超平面上的实例点保存在该结点。

  3. 直到两个子区域没有实例存在时,停止。

3.3.2搜索kd树

算法3.3

输入:已构造的kd树,目标点x;

输出:x的最近邻

  1. 首先找出包含目标点x的叶结点

    从根节点出发,递归向下,当目标点当前维的坐标小于切分点坐标,则移动到左子结点,否则,移动到右子结点,知道子结点为叶结点时停止。

  2. 以此叶结点为“当前最近点”

  3. 递归回退,在每个结点中进行以下操作

    1. 如果该结点保存的实例点比当前最近点离目标点更近,则以该实例点为“当前最近点”

    2. 当前最近点一定存在于某个父结点对应的子结点区域内,所以要检查该父结点对应的另一个子结点区域内是否有更近的节点。

      具体方法是,检查目标点与当前最近点所构成的超球体,是否与另一结点所在区域相交,实际的计算方式,就是看这个超球体是否经过父结点所形成的分割超平面,如果结果,代表该区域与另一结点相交。

      • 如果相交,那么在另一结点内,可能存在一个距目标结点更近的点,所以移动到另一个结点,递归进行最近邻搜索(先找出包含目标点x的叶结点,然后递归回退);
      • 如果不想交,则直接向上回退
  4. 当回退到根结点时,搜索结束,最后的“当前最近点”即为x的最近邻点

如果实例点是随机分布的,kd树搜索的平均计算复杂度是O(logN),主要使用于训练实例数远大于空间维数的情况,如果空间维数接近训练实例数,它的效率就会迅速下降,几乎接近线性扫描。

kd树搜索的好处是,如果某结点的分割面到当前最近点的距离大于当前最小距离,则该节点的另一侧结点完全不需要进行遍历,省掉了一部分结点的遍历时间。

代码实现:

#!/usr/bin/env python
# -*-coding:utf-8 -*-
'''
@File    :   knn.py
@Time    :   2022/04/23 11:37:39
@Author  :   zoro
@Version :   1.0
@Desc    :   k近邻实现
'''


from logging import root
from turtle import right
import numpy as np
import queue


class KDTreeNode():
    """
    kd tree的节点类
    """
    def __init__(self):
        # 左子节点
        self.left = None 
        # 右子节点
        self.right = None 
        # 父结点
        self.father = None 
        # 使用的特征索引
        self.feature_index = None 
        # 节点对应的x, y
        self.val = None 
        # 节点在树中所处的层级
        self.layer = None
    
    def __str__(self):
        return f"feature: {self.feature_index}, split: {self.val}"
    
    def brother(self):
        """_summary_

        Returns:
            _type_: 输出节点的兄弟节点,即同属一个父结点的另一个子结点
        """
        if self.father is None:
            ret = None
        else:
            if self.father.left is self:
                ret = self.father.right
            else:
                ret = self.father.left
        return ret


class KDTree():
    def __init__(self) -> None:
        """生成一颗kd树,首先生成一个空的根节点
        """
        self.root = KDTreeNode()

    def _pre_order_traverse(self, nd):
        """利用递归前序遍历kd树,好处是一个函数,比较简单的完成,尴尬的地方是,不存在一个返回值,不能直接在str里调用

        Args:
            nd (_type_): 开始节点 
        """
        if nd is not None:
            if nd.father is None:
                print(f"-1 -> {nd.layer}: feature_index: {nd.feature_index}, split_x: {nd.val}")
            else:
                print(f"{nd.father.layer} -> {nd.layer}: feature_index: {nd.feature_index}, split_x: {nd.val}")
            self._pre_order_traverse(nd.left)
            self._pre_order_traverse(nd.right)
            
    
    def __str__(self) -> str:
        """打印kd树,前序遍历

        Returns:
            str: _description_
        """
        ret = []
        i = 0
        que = [(self.root, -1)]
        while que:
            nd, idx_father = que.pop(0)
            ret.append("%d -> %d: %s" % (idx_father, i, str(nd)))
            if nd.left is not None:
                que.append((nd.left, i))
            if nd.right is not None:
                que.append((nd.right, i))
            i += 1
        return "\n".join(ret)

    def _get_median_index(self, X, index_list, feature_index) -> int: 
        """统计一系列索引数据中,在对应feature_index维度的中位数

        Args:
            X: 原始数据
            index_list (_type_): 数据索引列表 
            feature_index (_type_): 特征维度索引

        Returns:
            int: 中位数对应的索引
        """
        data_len = len(index_list)
        # 因为不需要真正的中位数,需要将接近中位数的点作为结点使用
        # 为了和案例保持一致,所以k的取值处理一下
        k = data_len // 2 if data_len % 2 != 0 else data_len // 2 - 1
        # k = data_len // 2
        col = list(map(lambda i: (i, X[feature_index][i]), index_list))
        sorted_index = list(map(lambda x: x[0], sorted(col, key=lambda x: x[1])))
        median_index = sorted_index[k] 
        return median_index
    
    def _split_feature_index(self, X, feature_index, index_list, median_index):
        """将index分为左右两部分,因为排序后再划分,要遍历两遍数据,所以直接依据大小过滤

        Args:
            index_list (_type_): _description_
            median_index (_type_): _description_
        
        Returns:
            list: 
        """
        left_index_list = []
        right_index_list = []
        median_val = X[feature_index][median_index]

        for idx in index_list:
            if median_index == idx:
                continue
            idx_val = X[feature_index][idx] 
            if idx_val > median_val:
                right_index_list.append(idx)
            else:
                left_index_list.append(idx)

        return left_index_list, right_index_list

    def build_tree(self, X, y):
        current_node = self.root
        current_node.layer = 0 
        feature_num = len(X)
        index_list = list(range(len(X.T)))
        index_queue = [(current_node, index_list)] 
        while len(index_queue) != 0:
            # 当队列中存在数据时,就进行迭代循环
            # 取出队列中第一个元素
            current_node, index_list = index_queue.pop(0) 
            # 确定要使用的特征,默认就是0层开始,直接mod取余数,就是0?默认使用第一个特征分割?
            # 此处专门使用的是从1开始的序号,而不是从0开始的索引
            # 原公式是(j mod k) + 1
            feature = current_node.layer % feature_num + 1
            feature_index = feature - 1
            # 依据选用的特征,切分数据集
            median_index = self._get_median_index(X, index_list, feature_index)
            # 切分点即为结点的value
            current_node.val = (X.T[median_index], y[median_index])
            # 切分点的特征索引,用第一个特征,实际上使用的索引
            current_node.feature_index = feature_index
            # 依据中位点,将数据切分为两部分
            # 此处本质上,还是一个前序遍历
            left_index_list, right_index_list = self._split_feature_index(X, feature_index, index_list, median_index) 
            if left_index_list != []:
                current_node.left = KDTreeNode()
                current_node.left.father = current_node
                current_node.left.layer = current_node.layer + 1
                # 压入栈
                index_queue.append((current_node.left, left_index_list))
            if right_index_list != []:
                current_node.right = KDTreeNode()
                current_node.right.father = current_node
                current_node.right.layer = current_node.layer + 1
                # 压入栈
                index_queue.append((current_node.right, right_index_list))
    
    def _search_tree(self, target_x, current_node):
        """搜索目标点的最近邻点

        Args:
            target_val (_type_): 目标节点
        """
        while current_node.left or current_node.right:
            if current_node.left is None:
                current_node = current_node.right
            elif current_node.right is None:
                current_node = current_node.left
            else:
                if target_x[current_node.feature_index] < current_node.val[0][current_node.feature_index]:
                    current_node = current_node.left
                else:
                    current_node = current_node.right
        
        return current_node
    
    def _get_eu_dist(self, node_1, node_2):
        """计算节点间的欧式距离

        Args:
            node_1 (_type_): _description_
            node_2 (_type_): _description_
        """
        # eu_dist = np.linalg.norm(node_1 - node_2)
        sum_of_square = sum(map(lambda x, y: (x - y) ** 2, node_1, node_2))
        eu_dist = np.sqrt(sum_of_square)
        return eu_dist


    def _get_dist_with_hyper(self, target_node, node) -> int:
        """计算目标节点,到某一个节点的分界面之间的距离

        Args:
            target_node (_type_): 目标节点
            node (_type_): 某一节点

        Returns:
            int: 距离
        """
        node_feature_value = node.val[0][node.feature_index]
        target_node_feature_value = target_node[node.feature_index]
        return np.sqrt((node_feature_value - target_node_feature_value) ** 2)


    def nearest_neighbor_search(self, target_x):
        """给定一个目标点,搜索其最近邻
        1. 首先找到叶结点
        2. 以叶结点为当前最近点
        3. 递归回退,在每个结点进行以下操作:
            (a) 如果该结点保存的实例点比当前更近,则更新当前最近点
            (b) 当前最近点一定存在于某结点的子结点对应的区域,检查该子结点对应的父结点的另一子结点的区域是否有更近的点

        Args:
            target_x (_type_): _description_
        """
        best_dist = float('inf')
        # 寻找包含目标结点的叶结点
        current_nearest_node = self._search_tree(target_x, self.root)
        traversed_node = []
        que = [(self.root, current_nearest_node)]
        traversed_node = []
        # 递归向上查找 
        while len(que) != 0:
            root_node, current_node = que.pop(0) 
            traversed_node.append(current_node)
            print(current_node)
            while 1:
                dist = self._get_eu_dist(target_x, current_node.val[0])
                # 首先判断当前节点是否离目标节点更近,如果更近,则将当前节点更新为目标节点
                if dist < best_dist:
                    best_dist = dist
                    current_nearest_node = current_node
                # 如果不是更近,判断当前节点是否是根节点,如果是根节点,代表已经搜索完毕,直接跳出循环
                if current_node is not self.root:
                    # 如果不是根节点,代表当前节点可能还存在兄弟节点
                    bro_node = current_node.brother()
                    if bro_node is not None and bro_node not in traversed_node:
                        # 如果存在兄弟节点,则判断目标节点与当前节点的父结点的分割面之间的距离,
                        # 如果距离小于当前最近距离
                        #    代表当前节点的兄弟节点区域,可能存在一个比当前节点离目标节点更近的点,直接搜索兄弟节点的叶子节点,并向上返回
                        # 如果距离大于等于当前最近距离
                        #    代表兄弟节点所构成区域,不存在比当前节点更近的节点,所以直接向上返回
                        dist_with_hyper = self._get_dist_with_hyper(target_x, current_node.father)
                        print(dist_with_hyper)
                        if dist > dist_with_hyper:
                            new_nearest_node = self._search_tree(target_x, bro_node)
                            print(f'new_nearest_node: {new_nearest_node}')
                            que.append((bro_node, new_nearest_node))
                            # 找到一个新的节点,以该节点开始重新递归
                            break
                    # 只要当前节点不是根节点,并且当前节点的兄弟节点分支不存在更近点,那就向上返回,一直到返回到父结点后,开始从下一个叶子节点开始向上递归
                    current_node = current_node.father
                    traversed_node.append(current_node)
                else:
                    break
        return current_nearest_node
                    
    def k_nearest_neighbor_search(self, target_x, k):
        """给定一个目标点,搜索其最近邻
        1. 首先找到叶结点
        2. 以叶结点为当前最近点
        3. 递归回退,在每个结点进行以下操作:
            (a) 如果该结点保存的实例点比当前更近,则更新当前最近点
            (b) 当前最近点一定存在于某结点的子结点对应的区域,检查该子结点对应的父结点的另一子结点的区域是否有更近的点

        Args:
            target_x (_type_): _description_
            k: 查找最近邻的k个节点
        """
        best_dist = float('inf')
        # 寻找包含目标结点的叶结点
        current_nearest_node = self._search_tree(target_x, self.root)

        # 构建一个长度为k的最近邻节点list
        k_nearest_node = []
        
        que = [(self.root, current_nearest_node)]
        traversed_node = []
        # 递归向上查找 
        while len(que) != 0:
            root_node, current_node = que.pop(0) 
            traversed_node.append(current_node)
            print(current_node)
            while 1:
                dist = self._get_eu_dist(target_x, current_node.val[0])
                # 首先判断当前节点是否离目标节点更近,如果更近,则将当前节点更新为目标节点

                if len(k_nearest_node) < k:
                    print(f'insert new node {current_node}')
                    k_nearest_node.append([current_node, dist])
                    best_dist = max([x[1] for x in k_nearest_node]) 

                else:
                    if dist < best_dist:
                        print(f'insert new node {current_node}')
                        k_nearest_node.append([current_node, dist])
                        # 过滤掉大于dist的node,可能会同时过滤掉两个,排序后丢掉最后一个,可以保护原来的顺序
                        k_nearest_node.sort(key = lambda x: x[1])
                        k_nearest_node.pop()
                        best_dist = k_nearest_node[-1][1]

                # 如果不是更近,判断当前节点是否是根节点,如果是根节点,代表已经搜索完毕,直接跳出循环
                if current_node is not self.root:
                    # 如果不是根节点,代表当前节点可能还存在兄弟节点
                    bro_node = current_node.brother()
                    if bro_node is not None and bro_node not in traversed_node:
                        # 如果存在兄弟节点,则判断目标节点与当前节点的父结点的分割面之间的距离,
                        # 如果距离小于当前最近距离
                        #    代表当前节点的兄弟节点区域,可能存在一个比当前节点离目标节点更近的点,直接搜索兄弟节点的叶子节点,并向上返回
                        # 如果距离大于等于当前最近距离
                        #    代表兄弟节点所构成区域,不存在比当前节点更近的节点,所以直接向上返回
                        dist_with_hyper = self._get_dist_with_hyper(target_x, current_node.father)
                        print(dist_with_hyper)
                        if dist > dist_with_hyper:
                            new_nearest_node = self._search_tree(target_x, bro_node)
                            print(f'new_nearest_node: {new_nearest_node}')
                            que.append((bro_node, new_nearest_node))
                            # 找到一个新的节点,以该节点开始重新递归
                            break
                    # 只要当前节点不是根节点,并且当前节点的兄弟节点分支不存在更近点,那就向上返回,一直到返回到父结点后,开始从下一个叶子节点开始向上递归
                    current_node = current_node.father
                    traversed_node.append(current_node)
                else:
                    break
        return k_nearest_node
                    
        

        
            
if __name__ == '__main__':
    # X = np.array([[2,3], [5,4], [9,6], [4,7], [8,1], [7,2]]).T
    # y = np.array([1, 1, 0, 0, 1, 1])
    X = np.array([[6.27, 5.5], [1.24, -2.86], [17.05,-12.79], [-6.88, -5.4], [-2.96, -0.5], [7.75, -22.68],
                 [10.80, -5.03], [-4.6, -10.55], [-4.96, 12.61], [1.75, 12.26], [15.31, -13.16], 
                 [7.83, 15.70], [14.63, -0.35]]).T
    y = np.array([1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])
    tree = KDTree()
    tree.build_tree(X, y)
    tree._pre_order_traverse(tree.root)
    print(tree)
    nearest_node = tree.nearest_neighbor_search([3,4.5])
    print(nearest_node)
    k_nearest_node = tree.k_nearest_neighbor_search([-1, -5], 3)
    print([[x[0].val, x[1]] for x in k_nearest_node])
    
posted @   zoro-zhao  阅读(54)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· Obsidian + DeepSeek:免费 AI 助力你的知识管理,让你的笔记飞起来!
· 分享4款.NET开源、免费、实用的商城系统
· 解决跨域问题的这6种方案,真香!
· 一套基于 Material Design 规范实现的 Blazor 和 Razor 通用组件库
· 5. Nginx 负载均衡配置案例(附有详细截图说明++)
点击右上角即可分享
微信分享提示