机器学习之K近邻算法实现

import operator

from numpy import array, tile


def create_dataset():
    _dataset = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    _labels = ['A', 'A', 'B', 'B']
    return _dataset, _labels


def classify(x: list, dataset: array, labels, k):
    """
    训练步骤:
          (1)计算输入点与样本数据点的距离
          (2)按照距离排序
          (3)获取距离最小的前k个点
          (4)确定前k个点在类别中出现的频率
          (5)返回前k个点出现频率最高的类别作为输入点的预测分类
    :param x: 用于训练的的数据,两个元素的列表
    :param dataset: 样本数据集
    :param labels: 标签向量
    :param k: 最近邻居数
    :return:
    """
    # 获取数据形状
    dataset_size = dataset.shape[0]
    # 求差
    diff_mat = tile(x, (dataset_size, 1)) - dataset
    # 求平方差
    sq_diff_mat = diff_mat ** 2
    # 求平方差的和
    sq_distance = sq_diff_mat.sum(axis=1)
    # 求距离
    distances = sq_distance ** 0.5
    # 排序
    sorted_distances = distances.argsort()
    # 分类统计,用于计算前k个最近距离
    class_count = {}
    for i in range(k):
        # 获取标签
        vote_label = labels[sorted_distances[i]]
        # 累加统计标签个数
        class_count[vote_label] = class_count.get(vote_label, 0) + 1
    # 对标签出现的次数排序
    sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]


if __name__ == '__main__':
    dataset, labels = create_dataset()
    print(classify([0, 0], dataset, labels, 3))

其他knn示例或者基于主流机器学习框架实现的knn代码地址:

https://gitee.com/navysummer/machine-learning/tree/master/knn

  

posted @ 2024-06-08 14:26  NAVYSUMMER  阅读(7)  评论(0编辑  收藏  举报
交流群 编程书籍