手写数字识别-使用TensorFlow构建和训练一个简单的神经网络

下面是一个具体的Python代码示例,展示如何使用TensorFlow实现一个简单的神经网络来解决手写数字识别问题(使用MNIST数据集)。以下是一个完整的Python代码示例,展示如何使用TensorFlow构建和训练一个简单的神经网络来进行手写数字识别。

MNIST数据集的训练集有60000个样本:

Python代码

import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import json
import os

# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 预处理数据
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255

# 构建神经网络模型
def create_model():
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))

    # 编译模型
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# 训练模型并保存
def train_and_save_model():
    model = create_model()
    history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
    model.save('mnist_model.h5')

    # 保存训练历史记录
    with open('training_history.json', 'w') as f:
        json.dump(history.history, f)

# 加载模型和历史记录
def load_model_and_history():
    model = tf.keras.models.load_model('mnist_model.h5')

    with open('training_history.json', 'r') as f:
        history = json.load(f)
    
    return model, history

# 评估模型
def evaluate_model(model):
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print("Test accuracy: {}".format(test_acc))

# 可视化训练过程
def plot_training_history(history):
    plt.plot(history['accuracy'], label='accuracy')
    plt.plot(history['val_accuracy'], label='val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim([0, 1])
    plt.legend(loc='lower right')
    plt.show()

# 检查是否已经存在模型和历史记录
if not os.path.exists('mnist_model.h5') or not os.path.exists('training_history.json'):
    train_and_save_model()

model, training_history = load_model_and_history()
evaluate_model(model)
plot_training_history(training_history)

代码解释

  1. 加载MNIST数据集

    mnist = tf.keras.datasets.mnist
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    
  2. 预处理数据

    • 将图像数据调整为 (28, 28, 1) 的形状。
    • 将像素值标准化为 [0, 1] 之间。
    train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
    test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
    
  3. 构建神经网络模型

    • 使用 Sequential 模型,按顺序添加层。
    • 添加卷积层、池化层、全连接层。
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    
  4. 编译模型

    • 使用 adam 优化器,损失函数为 sparse_categorical_crossentropy,评估指标为 accuracy
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
  5. 训练模型

    • 训练模型5个epochs,并使用验证数据集评估模型性能。
    history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
    
  6. 评估模型

    • 在测试集上评估模型性能,并打印测试准确率。
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print(f"Test accuracy: {test_acc}")
    
  7. 可视化训练过程

    • 绘制训练和验证准确率随epoch变化的曲线。
    plt.plot(history.history['accuracy'], label='accuracy')
    plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim([0, 1])
    plt.legend(loc='lower right')
    plt.show()
    

通过这个修正后的示例,应该可以正常运行并训练一个简单的神经网络模型来进行手写数字识别。

posted @ 2024-07-05 14:41  def_Class  阅读(2)  评论(0编辑  收藏  举报