Machine Learning: The KNN Algorithm

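This post covers both calling the sklearn library functions and a hand-written implementation of the KNN (k-nearest neighbors) algorithm. The code is given below.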

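Algorithm steps:

1. To determine the class of an unknown instance, use the known (labeled) instances as reference.

2. Choose the parameter k (the number of nearest known instances to consult).

3. Compute the distance between the unknown instance and every known instance.

4. Take the k known instances with the smallest distances and, by majority vote, assign the unknown instance to the most common class among them.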
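
The four steps above can be sketched in a few lines of Python. This is only a minimal illustration on made-up data, not the implementation developed later in this post: the function name `knn_predict`, the toy points, and the labels are all assumptions, and `math.dist` requires Python 3.8 or newer.

```python
import math
from collections import Counter


def knn_predict(known_points, known_labels, query, k):
    """Classify `query` by majority vote among its k nearest known points."""
    # Step 3: distance from the query to every known instance.
    distances = [
        (math.dist(point, query), label)
        for point, label in zip(known_points, known_labels)
    ]
    # Keep the k smallest distances (k was chosen in step 2).
    nearest = sorted(distances, key=lambda pair: pair[0])[:k]
    # Step 4: majority vote over the labels of the k nearest neighbors.
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]


# Toy data (hypothetical): two well-separated clusters labeled 'a' and 'b'.
points = [(1.0, 1.0), (1.2, 0.8), (0.9, 1.1), (5.0, 5.0), (5.2, 4.9), (4.8, 5.1)]
labels = ['a', 'a', 'a', 'b', 'b', 'b']

print(knn_predict(points, labels, (1.1, 1.0), k=3))   # query near the 'a' cluster
print(knn_predict(points, labels, (5.1, 5.0), k=3))   # query near the 'b' cluster
```

Sorting with an explicit key on the distance avoids comparing labels when two distances tie, and `Counter.most_common` keeps the vote counting short.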

 

Calling the library:

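from sklearn import neighbors
from sklearn import datasets

knn = neighbors.KNeighborsClassifier()   # default n_neighbors=5
iris = datasets.load_iris()              # load the iris dataset bundled with sklearn

knn.fit(iris.data, iris.target)

# predict the class of one sample: sepal length, sepal width, petal length, petal width (cm)
predictLabel = knn.predict([[5, 3, 5, 2]])
print(predictLabel)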
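
The snippet above trains on all 150 samples and predicts one hand-picked point, so it gives no indication of accuracy. Here is a rough sketch of how one might hold out part of the data and score the classifier, using the same 0.67 training share and k = 3 as the hand-written version in this post. The variable names and `random_state=0` are assumptions chosen for illustration.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

iris = datasets.load_iris()

# Hold out roughly one third of the samples for testing.
# random_state is fixed only to make this illustration repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.33, random_state=0
)

knn = KNeighborsClassifier(n_neighbors=3)   # k = 3 instead of the default 5
knn.fit(X_train, y_train)

accuracy = knn.score(X_test, y_test)        # fraction of test samples classified correctly
print("accuracy: {:.1f}%".format(accuracy * 100))
```

Unlike the per-row random draw in the hand-written loader, `train_test_split` fixes the size of the test set up front, so the split proportions are the same on every run.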

Implementing KNN by hand:

import random
import math
import operator


def loadData(filename, split, trainingSet=None, testSet=None):
    # Use fresh lists when none are passed in; mutable default arguments
    # would be shared between calls.
    if trainingSet is None:
        trainingSet = []
    if testSet is None:
        testSet = []
    with open(filename, 'r') as file:
        # read the data line by line, skipping blank lines
        dataset = [line.strip().split(',') for line in file if line.strip()]
    for row in dataset:
        for y in range(len(row) - 1):
            row[y] = float(row[y])          # convert every column except the last (the label) to float
        if random.random() < split:         # randomly assign the row to the training or the test set
            trainingSet.append(row)
        else:
            testSet.append(row)
    print("trainingSet", trainingSet)
    print("testSet", testSet)
    return trainingSet, testSet


def conclu_distance(a, b, length):       # Euclidean distance between two vectors over their first `length` components
    dis = 0
    for i in range(length):
        dis += pow(a[i] - b[i], 2)
    return math.sqrt(dis)


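def getNeibors(trainingSet, testinstance, k):
    # Arguments: the training set, the instance whose neighbors are wanted, and the number of neighbors k.
    # Returns the k nearest training rows; each row still carries its class label in the last column.
    distance = []
    length = len(testinstance) - 1      # the last column is the label, so leave it out of the distance
    for x in range(len(trainingSet)):
        dist = conclu_distance(testinstance, trainingSet[x], length)
        distance.append((trainingSet[x], dist))      # note the format: (row, distance) pairs
    distance.sort(key=operator.itemgetter(1))        # sort by distance, ascending
    neighbors = [row for row, _ in distance[:k]]     # keep the k closest rows
    return neighbors

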
def getResponse(neighbors):   # derive the prediction from the neighbors by majority vote
    classVotes = {}          # votes are tallied in a dict: label -> count
    for x in range(len(neighbors)):
        response = neighbors[x][-1]
        if response in classVotes:
            classVotes[response] += 1
        else:
            classVotes[response] = 1
    # sort the (label, count) pairs by count, highest first
    sortedVotes = sorted(classVotes.items(), key=operator.itemgetter(1), reverse=True)
    return sortedVotes[0][0]


def getAccuracy(testSet, predictions):      # percentage of test rows predicted correctly
    correct = 0
    for x in range(len(testSet)):
        if testSet[x][-1] == predictions[x]:
            correct += 1
    return (correct / float(len(testSet))) * 100

def main():
    trainingSet = []
    testSet = []
    split = 0.67
    loadData('irisdata.txt', split, trainingSet, testSet)
    predictions = []
    k = 3
    for x in range(len(testSet)):
        neighbors = getNeibors(trainingSet, testSet[x], k)
        result = getResponse(neighbors)
        predictions.append(result)
        print('>predicted', repr(result), ',actual=', repr(testSet[x][-1]))
    accuracy = getAccuracy(testSet, predictions)
    print('Accuracy:', repr(accuracy), '%')


if __name__ == '__main__':
    main()
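
Because `loadData` decides where each row goes with `random.random()`, the training and test sets, and therefore the reported accuracy, change on every run. One possible way to make the split repeatable is to draw from a seeded generator. The sketch below is an illustration under stated assumptions, not part of the original script: `split_rows`, its `seed` parameter, and the sample rows are hypothetical.

```python
import random


def split_rows(rows, split, seed=None):
    """Randomly send each row to the training set with probability `split`."""
    rng = random.Random(seed)      # private generator; a fixed seed repeats the same split
    training, test = [], []
    for row in rows:
        if rng.random() < split:
            training.append(row)
        else:
            test.append(row)
    return training, test


# Hypothetical rows in the same shape loadData produces: features first, label last.
rows = [[float(i), float(i) / 2, 'label-%d' % (i % 3)] for i in range(30)]

train_a, test_a = split_rows(rows, 0.67, seed=42)
train_b, test_b = split_rows(rows, 0.67, seed=42)

print(train_a == train_b and test_a == test_b)   # same seed, same split
print(len(train_a) + len(test_a) == len(rows))   # every row lands in exactly one set
```

Using a private `random.Random` instance rather than seeding the module-level generator leaves any other random draws in the program unaffected.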

 

posted @ 2018-09-10 11:06  去冰七分糖