K-近邻算法kNN

  K-近邻算法(k-Nearest Neighbor,简称kNN)采用测量不同特征值之间的距离方法进行分类,是一种常用的监督学习方法,其工作机制很简单:给定测试样本,基于某种距离量度找出训练集中与其靠近的k个训练样本,然后基于这k个“邻居”的信息进行预测。kNN算法属于懒惰学习,此类学习技术在训练阶段仅仅是把样本保存起来,训练时间靠小为零,在收到测试样本后在进行处理,所以可知kNN算法的缺点是计算复杂度高、空间复杂度高。但其也有优点,精度高、对异常值不敏感、无数据输入设定。

  借张图来说:

当k = 1时目标点有一个class2邻居,根据kNN算法的原理,目标点也为class2。

当k = 5时目标点有两个class2邻居,有三个class1的邻居,根据其原理,目标点的类别为class2。

算法流程

总体来说,KNN分类算法包括以下4个步骤:

①准备数据,对数据进行预处理 。

②计算测试样本点(也就是待分类点)到其他每个样本点的距离。

③对每个距离进行排序,然后选择出距离最小的K个点 。

④对K个点所属的类别进行比较,根据少数服从多数的原则,将测试样本点归入在K个点中占比最高的那一类 。

算法代码

package com.top.knn;

import com.top.constants.OrderEnum;
import com.top.matrix.Matrix;
import com.top.utils.MatrixUtil;

import java.util.*;


/**
 * @program: top-algorithm-set
 * @description: KNN k-临近算法进行分类
 * @author: Mr.Zhao
 * @create: 2020-10-13 22:03
 **/
public class KNN {
    public static Matrix classify(Matrix input, Matrix dataSet, Matrix labels, int k) throws Exception {
        if (dataSet.getMatrixRowCount() != labels.getMatrixRowCount()) {
            throw new IllegalArgumentException("矩阵训练集与标签维度不一致");
        }
        if (input.getMatrixColCount() != dataSet.getMatrixColCount()) {
            throw new IllegalArgumentException("待分类矩阵列数与训练集列数不一致");
        }
        if (dataSet.getMatrixRowCount() < k) {
            throw new IllegalArgumentException("训练集样本数小于k");
        }
        // 归一化
        int trainCount = dataSet.getMatrixRowCount();
        int testCount = input.getMatrixRowCount();
        Matrix trainAndTest = dataSet.splice(2, input);
        Map<String, Object> normalize = MatrixUtil.normalize(trainAndTest, 0, 1);
        trainAndTest = (Matrix) normalize.get("res");
        dataSet = trainAndTest.subMatrix(0, trainCount, 0, trainAndTest.getMatrixColCount());
        input = trainAndTest.subMatrix(0, testCount, 0, trainAndTest.getMatrixColCount());

        // 获取标签信息
        List<Double> labelList = new ArrayList<>();
        for (int i = 0; i < labels.getMatrixRowCount(); i++) {
            if (!labelList.contains(labels.getValOfIdx(i, 0))) {
                labelList.add(labels.getValOfIdx(i, 0));
            }
        }

        Matrix result = new Matrix(new double[input.getMatrixRowCount()][1]);
        for (int i = 0; i < input.getMatrixRowCount(); i++) {
            // 求向量间的欧式距离
            Matrix var1 = input.getRowOfIdx(i).extend(2, dataSet.getMatrixRowCount());
            Matrix var2 = dataSet.subtract(var1);
            Matrix var3 = var2.square();
            Matrix var4 = var3.sumRow();
            Matrix var5 = var4.pow(0.5);
            // 距离矩阵合并上labels矩阵
            Matrix var6 = var5.splice(1, labels);
            // 将计算出的距离矩阵按照距离升序排序
            var6.sort(0, OrderEnum.ASC);
            // 遍历最近的k个变量
            Map<Double, Integer> map = new HashMap<>();
            for (int j = 0; j < k; j++) {
                // 遍历标签种类数
                for (Double label : labelList) {
                    if (var6.getValOfIdx(j, 1) == label) {
                        map.put(label, map.getOrDefault(label, 0) + 1);
                    }
                }
            }
            result.setValue(i, 0, getKeyOfMaxValue(map));
        }
        return result;
    }

    /**
     * 取map中值最大的key
     *
     * @param map
     * @return
     */
    private static Double getKeyOfMaxValue(Map<Double, Integer> map) {
        if (map == null)
            return null;
        Double keyOfMaxValue = 0.0;
        Integer maxValue = 0;
        for (Double key : map.keySet()) {
            if (map.get(key) > maxValue) {
                keyOfMaxValue = key;
                maxValue = map.get(key);
            }
        }
        return keyOfMaxValue;
    }

}
KNN

注:其中的矩阵方法请参考https://github.com/ineedahouse/top-algorithm-set/blob/dev/src/main/java/com/top/matrix/Matrix.java

  升降序枚举类参考https://github.com/ineedahouse/top-algorithm-set/blob/dev/src/main/java/com/top/constants/OrderEnum.java

该算法为本人github项目中的一部分,地址为https://github.com/ineedahouse/top-algorithm-set

如果对你有帮助可以点个star~

参考

《机器学习》-周志华

《机器学习实战》-Peter Harrington

 
posted @ 2020-11-16 21:21  MrZhaoyx  阅读(347)  评论(0编辑  收藏  举报