机器学习——KNN

导入类库

 1 import numpy as np
 2 from sklearn.neighbors import KNeighborsClassifier
 3 from sklearn.model_selection import train_test_split
 4 from sklearn.preprocessing import StandardScaler
 5 from sklearn.linear_model import LinearRegression
 6 from sklearn.metrics import r2_score
 7 from sklearn.datasets import load_iris
 8 import matplotlib.pyplot as plt
 9 import pandas as pd
10 import seaborn as sns
# 熵增益
# 熵越大,信息量越大,蕴含的不确定性越大
KNN
1.计算待预测值到所有点的距离
2.对所有距离排序
3.找出前K个样本里面类别最多的类,作为待预测值的类别

代码

 1 A = np.array([[1, 1], [1, 1.5], [0.5, 1.5]])
 2 B = np.array([[3.0, 3.0], [3.0, 3.5], [2.8, 3.1]])
 3 
 4 
 5 def knn_pre_norm(point):
 6     a_len = np.linalg.norm(point - A, axis=1)
 7     b_len = np.linalg.norm(point - B, axis=1)
 8     print(a_len.min())
 9     print(b_len.min())
10 
11 
12 def knn_predict_rev(point):
13     X = np.array([[1, 1], [1, 1.5], [0.5, 1.5], [3.0, 3.0], [3.0, 3.5], [2.8, 3.1]])
14     Y = np.array([0, 0, 0, 1, 1, 1])
15 
16     knn = KNeighborsClassifier(n_neighbors=2)
17     knn.fit(X, Y)
18 
19     print(knn.predict(np.array([[1.0, 3.0]])))
20 
21 
22 def iris_linear():
23     # 加载iris数据
24     li = load_iris()
25     # 散点图
26     # plt.scatter(li.data[:, 0], li.data[:, 1], c=li.target)
27     # plt.scatter(li.data[:, 2], li.data[:, 3], c=li.target)
28     # plt.show()
29     # 分割测试集和训练集,测试集占整个数据集的比例是0.25
30     x_train, x_test, y_train, y_test = train_test_split(li.data, li.target, test_size=0.25)
31     # 创建KNN分类,使用最少5个邻居作为类别判断标准
32     knn = KNeighborsClassifier(n_neighbors=5)
33     # 训练数据
34     knn.fit(x_train, y_train)
35     # 预测测试集
36     # print(knn.predict(x_test))
37     # 预测np.array([[6.3, 3, 5.2, 2.3]])
38     print(knn.predict(np.array([[6.3, 3, 5.2, 2.3]])))
39     # 预测np.array([[6.3, 3, 5.2, 2.3]])所属各个类别的概率
40     print(knn.predict_proba(np.array([[6.3, 3, 5.2, 2.3]])))
41 
42 
43 if __name__ == '__main__':
44     # knn_predict_rev(None)
45     # knn_pre_norm(np.array([2.3,2.3]))
46     iris_linear()

 

posted @ 2018-10-05 09:43  BO00097  阅读(276)  评论(0编辑  收藏  举报