手写数字识别-小数据集
1.手写数字数据集
- from sklearn.datasets import load_digits
- digits = load_digits()
2.图片数据预处理
- x:归一化MinMaxScaler()
- y:独热编码OneHotEncoder()或to_categorical
- 训练集测试集划分
- 张量结构
3.设计卷积神经网络结构
- 绘制模型结构图,并说明设计依据。
4.模型训练
5.模型评价
- model.evaluate()
- 交叉表与交叉矩阵
- pandas.crosstab
- seaborn.heatmap
实现代码
# author: 陌攻
#
# Handwritten-digit recognition on MNIST with a single-hidden-layer MLP.
# Pipeline: load data -> flatten and scale pixels -> one-hot encode labels ->
# train -> save -> evaluate -> cross-tabulate actual vs. predicted digits.
#
# NOTE(review): the outline above this script asks for sklearn's load_digits
# and a convolutional network; this script uses MNIST and a dense network.
# Confirm which one is intended.

import struct  # unused below; kept from the original import list

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical

# Seeds NumPy's global RNG only.
seed = 7
np.random.seed(seed)

# Load the data.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Keep the integer test labels; y_test is one-hot encoded below and the
# cross-tabulation needs the plain digit per sample.
y_test_labels = y_test

# Flatten every image (axes 1 and 2) into a single float32 feature vector.
num_pixels = X_train.shape[1] * X_train.shape[2]
X_train = X_train.reshape(X_train.shape[0], num_pixels).astype('float32')
X_test = X_test.reshape(X_test.shape[0], num_pixels).astype('float32')

# Scale pixel values by 1/255.
X_train = X_train / 255
X_test = X_test / 255

# One-hot encode the targets.
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
num_classes = y_test.shape[1]


def baseline_model():
    """Build and compile the baseline MLP.

    Reads the module-level ``num_pixels`` and ``num_classes``. The network is
    one ReLU hidden layer with ``num_pixels`` units followed by a softmax
    output layer with ``num_classes`` units. The model summary is printed as
    a side effect.

    Returns:
        The compiled ``Sequential`` model (categorical cross-entropy loss,
        Adam optimizer, accuracy metric).
    """
    model = Sequential()
    model.add(Input(shape=(num_pixels,)))
    # `kernel_initializer` is the Keras 2+ name of the old `init` argument.
    model.add(Dense(num_pixels, kernel_initializer='random_normal',
                    activation='relu'))
    model.add(Dense(num_classes, kernel_initializer='random_normal',
                    activation='softmax'))
    model.summary()
    model.compile(loss='categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    return model


# Build the model.
model = baseline_model()

# Train the model (`epochs` is the Keras 2+ name of the old `nb_epoch`).
model.fit(X_train, y_train, validation_data=(X_test, y_test),
          epochs=10, batch_size=200, verbose=2)

# Save the model with Keras' own serializer instead of pickling it through
# joblib. To reload:
#   from tensorflow.keras.models import load_model
#   model = load_model('NumberModel.h5')
model.save('NumberModel.h5')

# Evaluate: scores[0] is the loss, scores[1] the accuracy metric.
scores = model.evaluate(X_test, y_test, verbose=0)
print("正确率: %.2f%%" % (scores[1] * 100))

# Predict class probabilities for the test set, then take the most probable
# class per sample to get back to integer digits.
y_pred = np.argmax(model.predict(X_test), axis=1)

# Cross-tabulate actual (rows) against predicted (columns) digits.
df = pd.crosstab(np.array(y_test_labels), y_pred,
                 rownames=['actual'], colnames=['predicted'])
print(df)

# Draw the table as a heatmap. fmt='d' prints the integer counts in full
# rather than in the default short float format; the line colour is lowercase
# 'g' because Matplotlib no longer accepts uppercase single-letter colours.
sns.heatmap(df, annot=True, fmt='d', cmap="YlGnBu",
            linewidths=0.2, linecolor='g')
plt.show()
运行结果图:
交叉矩阵
交叉表