# LeNet-5 实现 CIFAR-10 图片分类 (LeNet-5 for CIFAR-10 image classification)

import tensorflow as tf
from tensorflow.keras import datasets ,layers ,models
import matplotlib.pyplot as plt
from keras import regularizers
# Pull the CIFAR-10 train/test splits through the Keras datasets API.
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
num_classes = 10  # number of CIFAR-10 categories

# Cast the images to float32, then divide in place by 255
# (presumably mapping 8-bit pixel values into [0, 1]).
x_train = x_train.astype('float32')
x_train /= 255
x_test = x_test.astype('float32')
x_test /= 255
# LeNet-5 adapted to CIFAR-10 (32x32x3 inputs, num_classes outputs).
# Spatial sizes: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5,
# so Flatten produces 5 * 5 * 16 = 400 features.
model = tf.keras.models.Sequential([
    # C1: 6 feature maps, 5x5 kernels. input_shape only matters on the first layer.
    tf.keras.layers.Conv2D(filters=6, kernel_size=(5, 5), padding='valid', activation=tf.nn.relu,
                           input_shape=(32, 32, 3)),
    # S2: 2x2 average pooling, stride 2.
    tf.keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
    # C3: 16 feature maps, 5x5 kernels.
    tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5), padding='valid', activation=tf.nn.relu),
    # S4: 2x2 average pooling, stride 2.
    tf.keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
    tf.keras.layers.Flatten(),
    # C5/F6: LeNet-5's 120- and 84-unit fully connected layers, with L2 weight decay.
    # tf.keras.regularizers is used so all objects come from the same Keras package.
    tf.keras.layers.Dense(units=120, activation=tf.nn.relu,
                          kernel_regularizer=tf.keras.regularizers.l2(0.001)),
    tf.keras.layers.Dense(units=84, activation=tf.nn.relu,
                          kernel_regularizer=tf.keras.regularizers.l2(0.001)),
    # Softmax yields a probability distribution over the classes, which is what
    # sparse_categorical_crossentropy (from_logits=False) expects; sigmoid does not.
    tf.keras.layers.Dense(units=num_classes, activation=tf.nn.softmax),
])
model.summary()

# Adam optimizer; the sparse_* loss/metric take integer class ids as labels.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy'],
)

# Train for 5 epochs in mini-batches of 64, letting Keras hold out 20% of
# the training arrays for validation. The per-epoch metrics end up in history.
history = model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=5,
    validation_split=0.2,
)

# Visualize the training history: loss curves on the left, accuracy curves on the right.
metrics_by_name = history.history
loss = metrics_by_name["loss"]
val_loss = metrics_by_name["val_loss"]
acc = metrics_by_name["sparse_categorical_accuracy"]
val_acc = metrics_by_name["val_sparse_categorical_accuracy"]

plt.subplot(1, 2, 1)
plt.plot(loss, label="Training Loss")
plt.plot(val_loss, label="Validation Loss")
plt.title("Training and Validation Loss")
plt.xlabel("Epoch")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(acc, label="Training Acc")
plt.plot(val_acc, label="Validation Acc")
plt.title("Training and Validation Acc")
plt.xlabel("Epoch")
plt.legend()

# Without show() a plain `python script.py` run never displays the figure.
# With a blocking backend the script pauses here until the window is closed.
plt.show()
# Final score on the test split. Training never used it; validation data
# came from x_train via validation_split.
model.evaluate(
    x=x_test,
    y=y_test,
    verbose=2,
)

 

# posted @ 2022-11-07 21:56  山海自有归期  阅读(42)  评论(0)  编辑  收藏  举报