import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
%matplotlib inline
path='./1.KNN/1-KNN/exercise/data/1/1_2.bmp'
example_img=plt.imread(path)
plt.imshow(example_img)
<matplotlib.image.AxesImage at 0x9a48dd8>
# Batch-load the training set. Each image lives at
# ./1.KNN/1-KNN/exercise/data/<label>/<label>_<n>.bmp, with labels 0-9 and
# sample numbers n = 1..500; an image's label is the directory it sits in.
n_classes = 10
samples_per_class = 500
data = []
target = []
for i in range(n_classes):
    for j in range(1, samples_per_class + 1):
        img_arr = plt.imread('./1.KNN/1-KNN/exercise/data/%d/%d_%d.bmp' % (i, i, j))
        data.append(img_arr)
        target.append(i)
# A bare `len(...)` expression is only echoed inside a notebook (and only the
# last one per cell); print so both counts show when run as a script too.
print(len(data), len(target))
# Out (original run): 5000 5000
# Convert the per-image Python lists into ndarrays, which are much easier to
# slice, shuffle and reshape in bulk for the machine-learning steps below.
np_data, np_target = np.array(data), np.array(target)
print(np_data.shape, np_target.shape)
# Out (original run): (5000, 28, 28) (5000,)
# Shuffle images and labels together so row k of np_data keeps its label at
# np_target[k]. One permutation is drawn and applied to both arrays. The
# previous approach reseeded the global RNG and shuffled each array separately,
# which stays aligned only as long as both shuffle calls consume random
# numbers in exactly the same way.
sd = np.random.randint(1, 5000, size=1)[0]
print(sd)  # log the seed so a particular shuffle can be reproduced
rng = np.random.RandomState(sd)
perm = rng.permutation(len(np_data))
np_data = np_data[perm]
np_target = np_target[perm]
# Out (original run): 3892
# Hold out the last `n_test` shuffled samples for testing; train on the rest.
n_test = 50
split = len(np_data) - n_test
x_train, y_train = np_data[:split], np_target[:split]
x_test, y_test = np_data[split:], np_target[split:]
# KNN needs one feature row per sample. Flatten each 2-D image into a row
# (28 * 28 = 784 features for these images). The sizes are derived from the
# data rather than hard-coded.
x_train = x_train.reshape(len(x_train), -1)
x_test = x_test.reshape(len(x_test), -1)
# K-nearest-neighbours classifier: each prediction is a vote among the 15
# closest training rows. All other settings are scikit-learn defaults; the
# fitted repr below shows metric='minkowski' with p=2 (Euclidean distance)
# and uniform neighbour weights. `fit` returns the estimator itself.
knn = KNeighborsClassifier(n_neighbors=15).fit(x_train, y_train)
# Out (original run):
# KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
#            metric_params=None, n_jobs=1, n_neighbors=15, p=2,
#            weights='uniform')
# Predict labels for the held-out images. x_test is already flattened to
# (n_samples, n_features) when the split is made, so it is passed as-is.
# The earlier reshape(50, 784) repeated that with hard-coded sizes and would
# fail if the split changed.
y = knn.predict(x_test)
print('真实:', y_test)  # true labels
print('预测:', y)       # predicted labels
# Out (original run): 47 of 50 predictions matched. The misses were at
# positions 5, 9 and 17: true 8 predicted as 5, 5 and 3 respectively.
# Mean accuracy on the training rows and on the held-out test rows.
# Printed explicitly: outside a notebook a bare expression prints nothing.
print('train accuracy:', knn.score(x_train, y_train))
print('test accuracy:', knn.score(x_test, y_test))
# Out (original run): train 0.93595959..., test 0.94
# Load an external digit image to classify.
# (An alternative input used earlier: './1.KNN/1-KNN/exercise/数字1.png'.)
digit = plt.imread('./1.KNN/1-KNN/exercise/数字.jpg')
# Collapse colour to one grey channel so it matches the single-channel
# training images. Average only the colour channels:
#  - for an RGBA image (typical of PNG), the alpha channel would otherwise be
#    mixed into the grey value;
#  - an already-grey 2-D image has no axis 2 and is kept as-is.
if digit.ndim == 3:
    digit = digit[:, :, :3].mean(axis=2)
# NOTE(review): plt.imread returns floats in [0, 1] for PNG but integers
# 0-255 for JPEG/BMP. Before switching inputs, confirm the value range (and
# the digit/background polarity) matches the training images.
plt.imshow(digit)
# Out (original run): <matplotlib.image.AxesImage at 0x1343c588>
# Crop the region of the photo that contains the digit: rows 80-169 and
# columns 0-64. (An earlier attempt used rows 80-159, columns 65-144.)
new_digit = digit[80:170, 0:65]
plt.imshow(new_digit)
# Out (original run): <matplotlib.image.AxesImage at 0x147139e8>

# Height and width of the crop; these feed the resize factors when the crop
# is scaled to the training-image size.
(a_x, a_y) = new_digit.shape
print(a_x, a_y, type(a_x))
# Out (original run): 90 65 <class 'int'>
import scipy.ndimage as ndimage

# Resample the crop to the training-image size so it produces exactly the
# number of features the classifier was fitted on. The target size is read
# from the training array (28 x 28 in the original run) rather than
# hard-coded as 28.
img_h, img_w = np_data.shape[1:]
data_pre_test = ndimage.zoom(new_digit, zoom=(img_h / a_x, img_w / a_y))
plt.figure(figsize=(2, 2))
plt.imshow(data_pre_test)
# scikit-learn expects a 2-D (n_samples, n_features) input; here one row.
print(knn.predict(data_pre_test.reshape(1, -1)))
# Out (original run): predicted label 4, i.e. array([4])
# Persist the trained KNN model to disk, then load it back and re-run the
# last prediction to check the round trip.
try:
    import joblib
except ImportError:
    # Old scikit-learn releases bundled joblib here. The alias was deprecated
    # in 0.21 and removed in 0.23, so it is only a fallback.
    from sklearn.externals import joblib

joblib.dump(knn, 'number识别.m')
# Out (original run): ['number识别.m']

# SECURITY: joblib.load unpickles the file and can execute arbitrary code.
# Only load model files this project wrote itself.
knn2 = joblib.load('number识别.m')
print(knn2.predict(data_pre_test.reshape(1, -1))[0])
# Out (original run): 4