TensorFlow学习笔记——MNIST全连接模型实践

import os
from tensorflow.keras.datasets import mnist
import tensorflow as tf
from tensorflow.python.keras import Model
from tensorflow.python.keras.layers import Flatten, Dense

# Load MNIST through Keras. Presumably uint8 pixel arrays of shape (N, 28, 28)
# with integer labels, per the Keras dataset docs -- confirm for your TF version.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values from [0, 255] into [0, 1].
# Cast to float32 *before* dividing: uint8 / 255.0 would produce float64 copies
# (twice the memory), which Keras then converts to float32 for computation
# anyway. The resulting float32 values are the same either way.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Prefix of the TensorFlow-format weight checkpoint written during training.
checkpoint_save_path = './checkpoint/model.ckpt'


# Model definition
class MnistModel(tf.keras.Model):
    """Fully connected MNIST classifier: Flatten -> Dense(128, relu) -> Dense(10, softmax).

    Built entirely on the public ``tf.keras`` API so that it interoperates with
    the ``tf.keras`` optimizer, loss and callbacks used to train it.

    Subclassing the private ``tensorflow.python.keras.Model`` mixes two
    separate Keras implementations. On TF 2.6+ ``compile`` can then reject a
    ``tf.keras.optimizers.Adam`` instance ("Could not interpret optimizer
    identifier").

    The attribute names ``flatten``, ``dense1`` and ``dense2`` are unchanged, so
    object-based checkpoints saved by the previous version still map onto
    these layers.
    """

    def __init__(self):
        super(MnistModel, self).__init__()
        self.flatten = tf.keras.layers.Flatten()
        # Hidden layer: 128 ReLU units.
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        # Output layer: 10 units with softmax -> per-class probabilities,
        # matching the non-logit sparse cross-entropy loss used in compile().
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, x):
        """Run the forward pass.

        Args:
            x: Input batch; it is flattened to (batch, features) first.

        Returns:
            Softmax probabilities over the 10 classes for each sample.
        """
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)


# Instantiate the network defined above.
model = MnistModel()

# Training configuration: Adam optimizer, sparse categorical cross-entropy on
# the softmax output (labels presumably integer class ids), and the matching
# sparse accuracy metric.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['sparse_categorical_accuracy'],
)

# Checkpoint callback: writes only the weights (not the whole model) to
# checkpoint_save_path, and only when the callback's monitored quantity
# improves (default monitor is val_loss per the Keras docs -- confirm).
model_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

# Resume from an earlier run: a previous save is detected via the '.index'
# file that sits next to the checkpoint prefix.
if os.path.exists(f'{checkpoint_save_path}.index'):
    model.load_weights(checkpoint_save_path)

# Train for 5 epochs. The test split doubles as validation data, so (with the
# default monitor) it also decides which epoch's weights the checkpoint keeps.
model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    callbacks=[model_callback],
)

# Print the layer / parameter overview. A subclassed model presumably must be
# built (called on data) before summary() works, hence it runs after fit().
model.summary()

# 保存模型参数到文本,方便查看
# with open('./weight.txt', 'w') as f:
#     for i in model.trainable_variables:
#         f.write(str(i.name) + '\n')
#         f.write(str(i.shape) + '\n')
#         # f.write(str(i.numpy()) + '\n')   # 这行有问题

 

posted @ 2021-12-10 16:44  一朵包纸  阅读(80)  评论(0)  编辑  收藏  举报