机器学习十五----手写数字识别-小数据集
1.手写数字数据集
- from sklearn.datasets import load_digits
- digits = load_digits()
2.图片数据预处理
- x:归一化MinMaxScaler()
- y:独热编码OneHotEncoder()或to_categorical
- 训练集测试集划分
- 张量结构
# 2. Data preprocessing
# Relies on X_data (flattened 8x8 digit images) and Y_target (labels as a
# 2-D column) being defined earlier in the script -- TODO confirm shapes.

# Min-max scale the pixel features (MinMaxScaler with its default settings).
scaler = MinMaxScaler()
scaler.fit(X_data)
X_data = scaler.transform(X_data)
print("归一化后的数据:")
print(X_data)

# One-hot encode the labels.
# NOTE(review): .todense() returns an np.matrix rather than an ndarray; the
# evaluation step later in this file indexes its argmax result accordingly.
label_encoder = OneHotEncoder()
Y = label_encoder.fit_transform(Y_target).todense()
print("独热编码后的标签数据:")
print(Y)

# Reshape every sample into an 8x8 single-channel image for the Conv2D input.
X = X_data.reshape((-1, 8, 8, 1))
print(X)

# Hold out 20% of the samples for testing, stratified on the labels, with a
# fixed seed so the split is reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(
    X,
    Y,
    test_size=0.2,
    random_state=0,
    stratify=Y,
)
3.设计卷积神经网络结构
- 绘制模型结构图,并说明设计依据。
# 3. Design the convolutional neural network
# Every convolution shares the same 5x5 kernel, 'same' padding and ReLU
# activation; only the number of filters differs between layers.
_conv_kwargs = dict(kernel_size=(5, 5), padding='same', activation='relu')

_layer_stack = [
    # Convolution 1: the input shape is taken from the training images.
    Conv2D(filters=16, input_shape=X_train.shape[1:], **_conv_kwargs),
    # Convolution 2
    Conv2D(filters=32, **_conv_kwargs),
    # Pooling 1, followed by dropout for regularisation
    MaxPool2D(pool_size=(2, 2)),
    Dropout(0.25),
    # Convolution 3
    Conv2D(filters=32, **_conv_kwargs),
    # Pooling 2, followed by dropout
    MaxPool2D(pool_size=(2, 2)),
    Dropout(0.25),
    # Convolution 4
    Conv2D(filters=64, **_conv_kwargs),
    Flatten(),
    # Fully connected head
    Dense(128, activation='relu'),
    Dropout(0.25),
    # Output layer: softmax over the 10 digit classes
    Dense(10, activation='softmax'),
]

model = Sequential()
for _layer in _layer_stack:
    model.add(_layer)

# Print the layer-by-layer summary of the assembled network.
model.summary()
4.模型训练
- model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
- train_history = model.fit(x=X_train,y=Y_train,validation_split=0.2, batch_size=300,epochs=10,verbose=2)
# 4. Model training
# Categorical cross-entropy matches the one-hot encoded labels; accuracy is
# tracked as the reporting metric.
_compile_options = {
    'loss': 'categorical_crossentropy',
    'optimizer': 'adam',
    'metrics': ['accuracy'],
}
model.compile(**_compile_options)

# Train for 10 epochs with batches of 300 samples, using 20% of the training
# data for validation; verbose=2 logs one line per epoch.
train_history = model.fit(
    X_train,
    Y_train,
    validation_split=0.2,
    batch_size=300,
    epochs=10,
    verbose=2,
)
5.模型评价
# 5. Model evaluation
# evaluate() reports the loss followed by the metrics configured at compile
# time on the held-out test set.
score = model.evaluate(X_test, Y_test)
print(score)

# Predict a class label for every test sample.
# Sequential.predict_classes() was removed in TensorFlow 2.6, so take the
# argmax over the softmax probabilities returned by predict() instead.
# verbose=0 keeps the call silent, as predict_classes() was by default.
y_pre = np.argmax(model.predict(X_test, verbose=0), axis=1)
print(y_pre)

# Decode the one-hot test labels back to class indices.
# np.asarray(...).ravel() yields a flat 1-D array whether Y_test is an
# np.matrix (as produced by .todense()) or a plain ndarray; the previous
# `np.array(y_test)[0]` only worked for the np.matrix case.
y_test = np.asarray(np.argmax(Y_test, axis=1)).ravel()
y_true = y_test
print(y_true)

# Cross table of true vs. predicted labels. The result used to be discarded;
# print it so it is visible when the file is run as a script.
print(pd.crosstab(y_true, y_pre, rownames=["true"], colnames=["predict"]))

# Confusion-matrix heatmap.
a = pd.crosstab(y_true, y_pre)
df = pd.DataFrame(a)
# Single-letter colour codes must be lowercase in current matplotlib
# ('R' is rejected), so use 'r' for the red cell borders.
sns.heatmap(df, annot=True, cmap='summer', linewidths=0.2, linecolor='r')
plt.show()