《统计学习方法》第3章习题

习题3.1

习题3.2

根据例 3.2 构造的 kd 树，可知点 \(x=(3,4.5)^T\) 的最近邻点为 \((2,3)^T\)。

习题3.3

k 近邻法的关键在于构造相应的 kd 树。下面用 Python 实现 kd 树的构造与搜索：

import heapq
import numpy as np

class KDNode:
    """A single node of a binary k-d tree.

    The constructor stores its four arguments unchanged:

    * ``data``  - the sample point kept at this node.
    * ``axis``  - index of the coordinate associated with this node
      (defaults to 0).
    * ``left`` / ``right`` - child nodes, or ``None`` when absent.
    """

    def __init__(self, data, axis=0, left=None, right=None):
        self.data, self.axis = data, axis
        self.left, self.right = left, right

class KDTree:
    """k-d tree over the rows of a 2-D sample array with k-nearest-neighbour search.

    Typical use::

        tree = KDTree(points)
        tree.construct()
        nearest = tree.search(query, near_k=2)

    ``search`` builds the tree on demand if ``construct`` has not been called.
    """

    def __init__(self, data):
        """Remember the samples; the tree itself is built by ``construct``.

        Args:
            data: array with a ``shape`` attribute, laid out as
                (n_samples, n_features).
        """
        self.raw_data = data
        self.k = data.shape[1]
        # Defined up front so that search() can detect an unbuilt tree.
        self.root = None

    def construct(self):
        """Build the tree from the stored samples and keep its root in ``self.root``."""
        data = self.raw_data
        self.root = self._insert_node(data, 0)

    def search(self, x, near_k=1, p=2):
        """Return the ``near_k`` stored points closest to ``x``.

        Args:
            x: query point, indexable and subtractable from a stored point.
            near_k: number of neighbours wanted; must be at least 1.
            p: order passed to ``np.linalg.norm``. The subtree pruning relies
                on a single-coordinate difference never exceeding the
                distance, so use ``p >= 1`` (or ``np.inf``).

        Returns:
            ``np.ndarray`` whose rows are the neighbours, nearest first. It
            has fewer than ``near_k`` rows when the tree holds fewer points,
            and shape ``(0, self.k)`` when the tree is empty. The same array
            is also left in ``self.knn``.

        Raises:
            ValueError: if ``near_k`` is smaller than 1.
        """
        if near_k < 1:
            raise ValueError("near_k must be a positive integer")
        if getattr(self, "root", None) is None:
            self.construct()

        # Max-heap (via negated distances) of the best candidates so far.
        self.knn = []
        self._near_k = near_k
        self._seq = 0
        self._visit(self.root, x, p)

        # Largest negated distance == smallest distance, so reverse order
        # yields nearest first.  Only the numeric fields take part in the
        # ordering; node objects are never compared.
        ranked = sorted(self.knn, key=lambda entry: entry[:2], reverse=True)
        if ranked:
            self.knn = np.array([entry[2].data for entry in ranked])
        else:
            self.knn = np.empty((0, self.k))
        return self.knn

    def pre_order_traverse(self, node):
        """Print the data of ``node`` and its descendants in pre-order.

        ``None`` (e.g. the root of an empty tree) is accepted and prints nothing.
        """
        if node is None:
            return
        print(node.data)
        if node.left:
            self.pre_order_traverse(node.left)
        if node.right:
            self.pre_order_traverse(node.right)

    def _insert_node(self, data, depth=0):
        """Recursively build the subtree for ``data`` and return its root.

        The split axis cycles through the coordinates with ``depth``; the
        element at the middle index of the axis-sorted points becomes the
        node, the points before it form the left subtree and the points after
        it the right one. Returns ``None`` for an empty ``data``.
        """
        # len() rather than truthiness: data may be a NumPy array.
        if len(data) == 0:
            return None
        axis = depth % self.k
        data = sorted(data, key=lambda point: point[axis])
        middle = len(data) // 2
        return KDNode(
            data[middle],
            axis,
            self._insert_node(data[:middle], depth + 1),
            self._insert_node(data[middle + 1:], depth + 1),
        )

    def _visit(self, node, x, p=2):
        """Depth-first search step that updates the candidate heap ``self.knn``.

        Descends first into the child on the query's side of the splitting
        plane, then offers ``node`` itself as a candidate, and finally visits
        the other child only if it could still contain a closer point.
        """
        if node is None:
            return

        # Signed offset of the query from this node's splitting plane.
        dis = x[node.axis] - node.data[node.axis]
        if dis < 0:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left
        self._visit(near, x, p)

        curr_dis = np.linalg.norm(x - node.data, p)
        # The unique sequence number breaks distance ties, so tuple comparison
        # never falls through to the (unorderable) node objects.  Negating it
        # makes the earlier-visited point win a tie.
        entry = (-curr_dis, -self._seq, node)
        self._seq += 1
        if len(self.knn) < self._near_k:
            heapq.heappush(self.knn, entry)
        else:
            heapq.heappushpop(self.knn, entry)

        # The far side is worth a visit while we still lack candidates, or
        # when the splitting plane is closer than the worst kept candidate.
        if len(self.knn) < self._near_k or -self.knn[0][0] > abs(dis):
            self._visit(far, x, p)

if __name__ == "__main__":
    # Demo: build a kd tree from six 2-D sample points and print the two
    # nearest neighbours of the query point (3, 4.5).
    samples = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
    data = np.array(samples)
    query = np.array([3, 4.5])

    tree = KDTree(data)
    tree.construct()
    neighbours = tree.search(query, 2)
    print(neighbours)

调用 KDTree 的 search 方法即可查找 x 的 k 个近邻。对 \(x=(3,4.5)^T\)、\(k=2\)，结果为 \([(2,3)^T, (5,4)^T]\)。

posted @ 2021-06-24 13:51  程劼  阅读(194)  评论(1)  编辑  收藏  举报