educoder 机器学习 --- kNN算法

第一关:

复制代码
#encoding=utf8
import numpy as np

from collections import Counter

class kNNClassifier(object):
    def __init__(self, k):
        '''
        初始化函数
        :param k:kNN算法中的k
        '''
        self.k = k
        # 用来存放训练数据,类型为ndarray
        self.train_feature = None
        # 用来存放训练标签,类型为ndarray
        self.train_label = None


    def fit(self, feature, label):
        '''
        kNN算法的训练过程
        :param feature: 训练集数据,类型为ndarray
        :param label: 训练集标签,类型为ndarray
        :return: 无返回
        '''

        #********* Begin *********#
        self.train_feature = feature
        self.train_label = label
        #********* End *********#


    def predict(self, feature):
        '''
        kNN算法的预测过程
        :param feature: 测试集数据,类型为ndarray
        :return: 预测结果,类型为ndarray或list
        '''

        #********* Begin *********#
        result = []
        for data in feature:
            dist = np.sqrt(np.sum((self.train_feature - data) ** 2, axis = 1)) # 欧氏距离
            neighbor = np.argsort(dist)[0 : self.k]
            kLabel = (self.train_label[i] for i in neighbor)
            key, value = Counter(kLabel).most_common(1)[0] # 如果k个邻居中出现次数最多的label不止一个,要取总距离最小的label,这里直接取第一个(懒得写了
            result.append(key)
        return result
        #********* End *********#
复制代码

第2关:

复制代码
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

def classification(train_feature, train_label, test_feature):
    '''
    对test_feature进行红酒分类
    :param train_feature: 训练集数据,类型为ndarray
    :param train_label: 训练集标签,类型为ndarray
    :param test_feature: 测试集数据,类型为ndarray
    :return: 测试集数据的分类结果
    '''

    #********* Begin *********#
    #实例化StandardScaler函数
    scaler = StandardScaler()
    train_feature = scaler.fit_transform(train_feature)
    test_feature = scaler.transform(test_feature)
   
    #生成K近邻分类器
    clf = KNeighborsClassifier()
    #训练分类器
    clf.fit(train_feature, train_label)
    #进行预测
    predict_result = clf.predict(test_feature)
    return predict_result 
    #********* End **********#
复制代码

 

posted @   kafuuchino  阅读(20)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
点击右上角即可分享
微信分享提示