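Building a CNN with Keras
This post is based on https://www.jianshu.com/p/132746857e3a. It fixes one error in that article: the input shape 1,28,28 was changed to 28,28,1, which moves the channel axis to the last position. Keras on TensorFlow uses the channels-last layout (height, width, channels) by default, so with the original shape the model fails to build. The TensorFlow version used here is 2.3.1.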
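MNIST images load with shape (N, 28, 28) and no channel axis, so the fix comes down to where the channel axis is added. Below is a minimal NumPy sketch of the two layouts. It uses a small made-up array (`images`, 4 images) in place of the real dataset. It also shows how data already in channels-first order could be converted with `np.transpose`.

```python
import numpy as np

# Made-up stand-in for MNIST: 4 grayscale 28x28 images, shape (N, H, W)
images = np.arange(4 * 28 * 28, dtype=np.int64).reshape(4, 28, 28)

# channels_last (Keras/TensorFlow default): (N, H, W, C)
channels_last = images.reshape(images.shape[0], 28, 28, 1)

# Equivalent way to append a channel axis of size 1
expanded = np.expand_dims(images, axis=-1)

# channels_first, the layout used in the original article: (N, C, H, W)
channels_first = images.reshape(images.shape[0], 1, 28, 28)

# Convert channels_first -> channels_last by moving axis 1 to the end
converted = np.transpose(channels_first, (0, 2, 3, 1))

print(channels_last.shape)   # (4, 28, 28, 1)
print(channels_first.shape)  # (4, 1, 28, 28)
print(converted.shape)       # (4, 28, 28, 1)
```

With a single channel, `reshape` and `expand_dims` give the same result. For multi-channel data stored as (N, C, H, W), `transpose` is the correct choice, because `reshape` would scramble the pixels.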
```python
# -*- coding: utf-8 -*-
# 3. Import libraries and modules
import numpy as np
from keras.models import Sequential, load_model
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.utils import to_categorical
from keras.datasets import mnist

# 4. Load pre-shuffled MNIST data into train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
print(X_train.shape)  # (60000, 28, 28): no channel axis yet

# 5. Preprocess input data
# Append a channel axis: Keras expects channels_last, i.e. (N, 28, 28, 1)
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
# Convert to float and scale pixel values from [0, 255] to [0, 1]
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255

# 6. Preprocess class labels: integers 0-9 -> one-hot vectors of length 10
Y_train = to_categorical(y_train, 10)
Y_test = to_categorical(y_test, 10)

# 7. Define model architecture
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))

# 8. Compile model
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# 9. Fit model on training data
model.fit(X_train, Y_train,
          batch_size=32, epochs=10, verbose=1)

# 10. Save model to file
model.save("cnn_model.h5")

# 11. Load trained model
model2 = load_model("cnn_model.h5")

# 12. Evaluate model on test data; returns [loss, accuracy]
score = model2.evaluate(X_test, Y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
```
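To see why the input shape matters, you can work out each layer's output shape and parameter count from step 7 by hand. The sketch below does this in plain Python. It assumes Keras defaults for this model: `'valid'` padding and stride 1 for `Conv2D`, and a 2x2 pool with stride 2 for `MaxPooling2D`. The helpers `conv_valid` and `layer_shapes` are hypothetical and are not part of Keras.

```python
def conv_valid(size, kernel):
    """Output length along one axis of a stride-1 convolution with 'valid' padding."""
    return size - kernel + 1


def layer_shapes(input_shape):
    """Hypothetical helper: trace (height, width, channels) through the step-7 model.

    Returns the shape after pooling, the flattened length and the total number
    of trainable parameters. Raises ValueError if a convolution would produce
    a non-positive spatial size.
    """
    h, w, c = input_shape  # interpreted as channels_last: (height, width, channels)
    params = 0
    for filters in (32, 32):  # two Conv2D(32, (3, 3)) layers
        h, w = conv_valid(h, 3), conv_valid(w, 3)
        if h <= 0 or w <= 0:
            raise ValueError(f"Conv2D output would be {h}x{w}: input too small")
        params += 3 * 3 * c * filters + filters  # kernel weights + one bias per filter
        c = filters
    h, w = h // 2, w // 2       # MaxPooling2D(pool_size=(2, 2)); Dropout adds nothing
    flat = h * w * c            # Flatten
    params += flat * 128 + 128  # Dense(128)
    params += 128 * 10 + 10     # Dense(10)
    return (h, w, c), flat, params


print(layer_shapes((28, 28, 1)))  # ((12, 12, 32), 4608, 600810)

try:
    layer_shapes((1, 28, 28))     # the article's original shape, read as channels_last
except ValueError as err:
    print(err)                    # Conv2D output would be -1x26: input too small
```

With `(1, 28, 28)`, the first 3x3 convolution would have to slide over a height of 1. This is why the original shape fails under the channels-last default. Calling `model.summary()` on the step-7 model should report the same shapes and total.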
The training output is as follows:
```
Epoch 1/10
1875/1875 [==============================] - 49s 26ms/step - loss: 0.2027 - accuracy: 0.9388
Epoch 2/10
1875/1875 [==============================] - 53s 28ms/step - loss: 0.0843 - accuracy: 0.9745
Epoch 3/10
1875/1875 [==============================] - 51s 27ms/step - loss: 0.0654 - accuracy: 0.9805
Epoch 4/10
1875/1875 [==============================] - 52s 28ms/step - loss: 0.0546 - accuracy: 0.9827
Epoch 5/10
1875/1875 [==============================] - 54s 29ms/step - loss: 0.0455 - accuracy: 0.9858
Epoch 6/10
1875/1875 [==============================] - 54s 29ms/step - loss: 0.0397 - accuracy: 0.9877
Epoch 7/10
1875/1875 [==============================] - 54s 29ms/step - loss: 0.0371 - accuracy: 0.9883
Epoch 8/10
1875/1875 [==============================] - 53s 28ms/step - loss: 0.0317 - accuracy: 0.9896
Epoch 9/10
1875/1875 [==============================] - 55s 29ms/step - loss: 0.0297 - accuracy: 0.9905
Epoch 10/10
1875/1875 [==============================] - 51s 27ms/step - loss: 0.0284 - accuracy: 0.9912
```
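The `1875/1875` counter on each epoch line is the number of batches per epoch: the 60,000 training images split into batches of `batch_size=32`. The sketch below reproduces that arithmetic. The function name `steps_per_epoch` is only illustrative. A final partial batch, if there is one, counts as an extra step, hence the ceiling.

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Number of batches per epoch; a partial last batch still counts as one step."""
    return math.ceil(num_samples / batch_size)


print(steps_per_epoch(60000, 32))  # 1875 (divides evenly)
print(steps_per_epoch(60000, 64))  # 938 (937 full batches + 1 partial batch of 32)
```

Note that the accuracy in this log is measured on the training data. The test-set accuracy is what `evaluate` returns in step 12.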