模型加载与保存

TensorFlow2教程完整教程目录(更有python、go、pytorch、tensorflow、爬虫、人工智能教学等着你):https://www.cnblogs.com/nickchen121/p/10840284.html

Outline

  • save/load weights # 记录部分信息
  • save/load entire model # 记录所有信息
  • saved_model # 通用,包括Pytorch、其他语言

Save/load weights

  • 保存部分信息
# Illustrative fragment: `model`, `create_model`, `test_images` and
# `test_labels` are defined outside this snippet.

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
# A fresh model object is created first and the saved weights are then
# loaded into it.
# NOTE(review): presumably create_model() must rebuild the same architecture
# that produced the checkpoint - confirm against its definition.
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

# NOTE(review): unpacking into (loss, acc) assumes the model was compiled with
# exactly one metric - confirm in create_model().
loss, acc = model.evaluate(test_images, test_labels)
print(f'Restored model, accuracy: {100*acc:5.2f}')
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics


def preprocess(x, y):
    """Turn one raw (image, label) pair into model-ready tensors.

    `x` is a single image, not a batch. It is cast to float32, divided by
    255 and reshaped to a flat vector of 28 * 28 values. `y` is cast to
    int32 and one-hot encoded with depth 10.

    Returns the (image, label) tuple of converted tensors.
    """
    pixels = tf.reshape(tf.cast(x, dtype=tf.float32) / 255., [28 * 28])
    label = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)
    return pixels, label


# End-to-end demo: train a small MNIST classifier, save only its weights,
# rebuild the same architecture from scratch and reload the weights into it.
batchsz = 128
# load_data() returns a training split and a held-out split.
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

# Training pipeline: per-sample preprocessing, shuffle, then batch.
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
# Validation pipeline: same preprocessing, no shuffling.
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)

# Pull one batch to check the tensor shapes produced by the pipeline.
sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

# Fully connected classifier; the last layer has no activation, so it
# outputs raw logits (matched by from_logits=True in the loss below).
network = Sequential([
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(32, activation='relu'),
    layers.Dense(10)
])
network.build(input_shape=(None, 28 * 28))
network.summary()

# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and is rejected by newer Keras releases.
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

# validation_freq=2: validation is run on every second epoch only.
network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)

network.evaluate(ds_val)

# Write the weights to disk, then drop the in-memory model so the reload
# below starts from a fresh object.
network.save_weights('weights.ckpt')
print('saved weights.')
del network

# The layer stack is re-declared by hand before the weights are loaded,
# mirroring the original definition above.
network = Sequential([
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(32, activation='relu'),
    layers.Dense(10)
])
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.load_weights('weights.ckpt')
print('loaded weights!')
network.evaluate(ds_val)
datasets: (60000, 28, 28) (60000,) 0 255
(128, 784) (128, 10)
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  200960    
_________________________________________________________________
dense_1 (Dense)              multiple                  32896     
_________________________________________________________________
dense_2 (Dense)              multiple                  8256      
_________________________________________________________________
dense_3 (Dense)              multiple                  2080      
_________________________________________________________________
dense_4 (Dense)              multiple                  330       
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
Epoch 1/3
469/469 [==============================] - 5s 12ms/step - loss: 0.2876 - accuracy: 0.8335
Epoch 2/3
469/469 [==============================] - 5s 11ms/step - loss: 0.1430 - accuracy: 0.9551 - val_loss: 0.1397 - val_accuracy: 0.9634
Epoch 3/3
469/469 [==============================] - 4s 9ms/step - loss: 0.1155 - accuracy: 0.9681
79/79 [==============================] - 1s 8ms/step - loss: 0.1344 - accuracy: 0.9654
saved weights.
loaded weights!
79/79 [==============================] - 1s 13ms/step - loss: 0.1344 - accuracy: 0.9593





[0.13439734456132318, 0.9654]

Save/load entire model

  • 完美保存所有信息
# Illustrative fragment: `network`, `x_val` and `y_val` are defined outside
# this snippet.

# Write the model to a single HDF5 file.
network.save('model.h5')
print('saved total model.')
# Drop the in-memory model so the one used below comes from the file.
del network

print('load model from file')
# Rebuild the model object directly from the saved file.
network = tf.keras.models.load_model('model.h5')

# NOTE(review): x_val / y_val are passed as-is; presumably they must already
# be in the same format the model was trained on - confirm against the caller.
network.evaluate(x_val, y_val)
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics


def preprocess(x, y):
    """Prepare a single (image, label) example for the network.

    `x` is one image, not a batch: it is converted to float32, divided by
    255 and flattened to 28 * 28 values. `y` is converted to int32 and
    expanded to a one-hot vector of depth 10.

    Returns the converted (image, label) pair.
    """
    image = tf.cast(x, dtype=tf.float32) / 255.
    flat_image = tf.reshape(image, [28 * 28])
    one_hot_label = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)
    return flat_image, one_hot_label


# End-to-end demo: train a small MNIST classifier, save the entire model to
# an HDF5 file, load it back and evaluate the loaded copy.
batchsz = 128
# load_data() returns a training split and a held-out split.
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

# Training pipeline: per-sample preprocessing, shuffle, then batch.
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
# Validation pipeline: same preprocessing, no shuffling.
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)

# Pull one batch to check the tensor shapes produced by the pipeline.
sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

# Fully connected classifier; the last layer has no activation, so it
# outputs raw logits (matched by from_logits=True in the loss below).
network = Sequential([
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(32, activation='relu'),
    layers.Dense(10)
])
network.build(input_shape=(None, 28 * 28))
network.summary()

# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and is rejected by newer Keras releases.
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

# validation_freq=2: validation is run on every second epoch only.
network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)

network.evaluate(ds_val)

# Save the whole model to one file, then drop the in-memory copy so the
# evaluation below can only use what was read back from disk.
network.save('model.h5')
print('saved total model.')
del network

print('load model from file')

network1 = tf.keras.models.load_model('model.h5')
# The loaded model is compiled again with the same optimizer, loss and
# metric settings that were used for training.
network1.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                 loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                 metrics=['accuracy'])
# Rebuild the validation set by applying the same steps as preprocess() to
# the whole array at once: scale, flatten each image, one-hot the labels.
x_val = tf.cast(x_val, dtype=tf.float32) / 255.
x_val = tf.reshape(x_val, [-1, 28 * 28])
y_val = tf.cast(y_val, dtype=tf.int32)
y_val = tf.one_hot(y_val, depth=10)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(128)
network1.evaluate(ds_val)
datasets: (60000, 28, 28) (60000,) 0 255
(128, 784) (128, 10)
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_20 (Dense)             multiple                  200960    
_________________________________________________________________
dense_21 (Dense)             multiple                  32896     
_________________________________________________________________
dense_22 (Dense)             multiple                  8256      
_________________________________________________________________
dense_23 (Dense)             multiple                  2080      
_________________________________________________________________
dense_24 (Dense)             multiple                  330       
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
Epoch 1/3
469/469 [==============================] - 6s 13ms/step - loss: 0.2851 - accuracy: 0.8405
Epoch 2/3
469/469 [==============================] - 6s 13ms/step - loss: 0.1365 - accuracy: 0.9580 - val_loss: 0.1422 - val_accuracy: 0.9590
Epoch 3/3
469/469 [==============================] - 5s 11ms/step - loss: 0.1130 - accuracy: 0.9661
79/79 [==============================] - 1s 10ms/step - loss: 0.1201 - accuracy: 0.9714
saved total model.
load model from file


W0525 16:44:50.178785 4587234752 hdf5_format.py:266] Sequential models without an `input_shape` passed to the first layer cannot reload their optimizer state. As a result, your model isstarting with a freshly initialized optimizer.


79/79 [==============================] - 1s 7ms/step - loss: 0.1201 - accuracy: 0.9672





[0.12005392337660747, 0.9714]

saved_model

  • 通用,包括Pytorch、其他语言
  • 用于工业环境的部署
# Illustrative fragment: `m` and `path` are defined outside this snippet.

# Export the model object `m` in SavedModel format to the given directory.
tf.saved_model.save(m, '/tmp/saved_model/')

# NOTE(review): `path` presumably points at the directory written above
# ('/tmp/saved_model/') - confirm.
imported = tf.saved_model.load(path)
# Look up the callable stored under the 'serving_default' signature key.
f = imported.signatures['serving_default']
# NOTE(review): the keyword name `x` and the [1, 28, 28, 3] input shape must
# match the exported signature - confirm against the model behind `m`.
print(f(x=tf.ones([1, 28, 28, 3])))
posted @ 2019-05-25 16:47  B站-水论文的程序猿  阅读(777)  评论(0)  编辑  收藏  举报