……

TensorFlow2教程-Keras概述

Keras 是一个用于构建和训练深度学习模型的高阶 API。它可用于快速设计原型、高级研究和生产。

Keras的3个优点: 方便用户使用、模块化和可组合、易于扩展

1 导入tf.keras

TensorFlow2推荐使用tf.keras构建网络,常见的神经网络都包含在tf.keras.layer中(最新的tf.keras的版本可能和keras不同)

In [1]:
import tensorflow as tf
from tensorflow.keras import layers
print(tf.__version__)
print(tf.keras.__version__)
 
/home/doit/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
 
2.0.0
2.2.4-tf
 

2 构建简单模型

2.1 模型堆叠

最常见的模型类型是层的堆叠:tf.keras.Sequential 模型

In [2]:
model = tf.keras.Sequential()
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
 

2.2 网络配置

tf.keras.layers中主要的网络配置参数如下:

activation:设置层的激活函数。此参数可以是函数名称字符串,也可以是函数对象。默认情况下,系统不会应用任何激活函数。

kernel_initializer 和 bias_initializer:创建层权重(核和偏置)的初始化方案。此参数是一个名称或可调用的函数对象,默认为 "Glorot uniform" 初始化器。

kernel_regularizer 和 bias_regularizer:应用层权重(核和偏置)的正则化方案,例如 L1 或 L2 正则化。默认情况下,系统不会应用正则化函数。

In [3]:
layers.Dense(32, activation='sigmoid')
layers.Dense(32, activation=tf.sigmoid)
layers.Dense(32, kernel_initializer='orthogonal')
layers.Dense(32, kernel_initializer=tf.keras.initializers.glorot_normal)
layers.Dense(32, kernel_regularizer=tf.keras.regularizers.l2(0.01))
layers.Dense(32, kernel_regularizer=tf.keras.regularizers.l1(0.01))
Out[3]:
<tensorflow.python.keras.layers.core.Dense at 0x7f0d70476518>
 

3 训练和评估

3.1 设置训练流程

构建好模型后,通过调用 compile 方法配置该模型的学习流程:

In [4]:
model = tf.keras.Sequential()
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
             loss=tf.keras.losses.categorical_crossentropy,
             metrics=[tf.keras.metrics.categorical_accuracy])
 

3.2 输入Numpy数据

对于小型数据集,可以使用Numpy构建输入数据。

In [5]:
import numpy as np

train_x = np.random.random((1000, 72))
train_y = np.random.random((1000, 10))

val_x = np.random.random((200, 72))
val_y = np.random.random((200, 10))

model.fit(train_x, train_y, epochs=10, batch_size=100,
          validation_data=(val_x, val_y))
 
Train on 1000 samples, validate on 200 samples
Epoch 1/10
1000/1000 [==============================] - 1s 503us/sample - loss: 11.8979 - categorical_accuracy: 0.0920 - val_loss: 12.1516 - val_categorical_accuracy: 0.1550
Epoch 2/10
1000/1000 [==============================] - 0s 21us/sample - loss: 12.2874 - categorical_accuracy: 0.0910 - val_loss: 12.9158 - val_categorical_accuracy: 0.1150
Epoch 3/10
1000/1000 [==============================] - 0s 31us/sample - loss: 13.3758 - categorical_accuracy: 0.0940 - val_loss: 14.4959 - val_categorical_accuracy: 0.0900
Epoch 4/10
1000/1000 [==============================] - 0s 28us/sample - loss: 15.4028 - categorical_accuracy: 0.0920 - val_loss: 17.2651 - val_categorical_accuracy: 0.1000
Epoch 5/10
1000/1000 [==============================] - 0s 24us/sample - loss: 18.6433 - categorical_accuracy: 0.0930 - val_loss: 21.0552 - val_categorical_accuracy: 0.1000
Epoch 6/10
1000/1000 [==============================] - 0s 31us/sample - loss: 22.0638 - categorical_accuracy: 0.0930 - val_loss: 23.7690 - val_categorical_accuracy: 0.1000
Epoch 7/10
1000/1000 [==============================] - 0s 25us/sample - loss: 24.8034 - categorical_accuracy: 0.0930 - val_loss: 27.7582 - val_categorical_accuracy: 0.1000
Epoch 8/10
1000/1000 [==============================] - 0s 24us/sample - loss: 30.0158 - categorical_accuracy: 0.0920 - val_loss: 34.5013 - val_categorical_accuracy: 0.1000
Epoch 9/10
1000/1000 [==============================] - 0s 27us/sample - loss: 37.2771 - categorical_accuracy: 0.0920 - val_loss: 42.4729 - val_categorical_accuracy: 0.1000
Epoch 10/10
1000/1000 [==============================] - 0s 30us/sample - loss: 45.5502 - categorical_accuracy: 0.0930 - val_loss: 51.9351 - val_categorical_accuracy: 0.1000
Out[5]:
<tensorflow.python.keras.callbacks.History at 0x7f0d702641d0>
 

3.3 tf.data输入数据

对于大型数据集可以使用tf.data构建训练输入。

In [6]:
dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
dataset = dataset.batch(32)
dataset = dataset.repeat()
val_dataset = tf.data.Dataset.from_tensor_slices((val_x, val_y))
val_dataset = val_dataset.batch(32)
val_dataset = val_dataset.repeat()

model.fit(dataset, epochs=10, steps_per_epoch=30,
          validation_data=val_dataset, validation_steps=3)
 
Train for 30 steps, validate for 3 steps
Epoch 1/10
30/30 [==============================] - 0s 12ms/step - loss: 67.4564 - categorical_accuracy: 0.0948 - val_loss: 87.7801 - val_categorical_accuracy: 0.0938
Epoch 2/10
30/30 [==============================] - 0s 2ms/step - loss: 110.4207 - categorical_accuracy: 0.0983 - val_loss: 137.2176 - val_categorical_accuracy: 0.0729
Epoch 3/10
30/30 [==============================] - 0s 2ms/step - loss: 166.3288 - categorical_accuracy: 0.1026 - val_loss: 200.0635 - val_categorical_accuracy: 0.0729
Epoch 4/10
30/30 [==============================] - 0s 2ms/step - loss: 234.6779 - categorical_accuracy: 0.0929 - val_loss: 276.1790 - val_categorical_accuracy: 0.0938
Epoch 5/10
30/30 [==============================] - 0s 1ms/step - loss: 316.2306 - categorical_accuracy: 0.0855 - val_loss: 362.6940 - val_categorical_accuracy: 0.0938
Epoch 6/10
30/30 [==============================] - 0s 2ms/step - loss: 405.1462 - categorical_accuracy: 0.0962 - val_loss: 454.0105 - val_categorical_accuracy: 0.0938
Epoch 7/10
30/30 [==============================] - 0s 3ms/step - loss: 495.8588 - categorical_accuracy: 0.0897 - val_loss: 542.2636 - val_categorical_accuracy: 0.1042
Epoch 8/10
30/30 [==============================] - 0s 2ms/step - loss: 589.8635 - categorical_accuracy: 0.1132 - val_loss: 637.4122 - val_categorical_accuracy: 0.0833
Epoch 9/10
30/30 [==============================] - 0s 2ms/step - loss: 679.3736 - categorical_accuracy: 0.1079 - val_loss: 724.9229 - val_categorical_accuracy: 0.1146
Epoch 10/10
30/30 [==============================] - 0s 1ms/step - loss: 757.8416 - categorical_accuracy: 0.1004 - val_loss: 787.8435 - val_categorical_accuracy: 0.0938
Out[6]:
<tensorflow.python.keras.callbacks.History at 0x7f0d6841fc88>
 

3.4 评估与预测

评估和预测函数:tf.keras.Model.evaluate和tf.keras.Model.predict方法,都可以可以使用NumPy和tf.data.Dataset构造的输入数据进行评估和预测

In [7]:
# 模型评估
test_x = np.random.random((1000, 72))
test_y = np.random.random((1000, 10))
model.evaluate(test_x, test_y, batch_size=32)
test_data = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_data = test_data.batch(32).repeat()
model.evaluate(test_data, steps=30)
 
1000/1 [================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================] - 0s 26us/sample - loss: 820.5671 - categorical_accuracy: 0.1000
30/30 [==============================] - 0s 1ms/step - loss: 786.1435 - categorical_accuracy: 0.0969
Out[7]:
[786.1434997558594, 0.096875]
In [8]:
# 模型预测
result = model.predict(test_x, batch_size=32)
print(result)
 
[[0.245288   0.00419576 0.         ... 0.11711721 0.         0.        ]
 [0.29617038 0.00480588 0.         ... 0.14704514 0.         0.        ]
 [0.15192655 0.00397816 0.         ... 0.1538676  0.         0.        ]
 ...
 [0.23659316 0.00576886 0.         ... 0.13454369 0.         0.        ]
 [0.31788164 0.0062953  0.         ... 0.14174958 0.         0.        ]
 [0.40483308 0.01813794 0.         ... 0.15039128 0.         0.        ]]
 

4 构建复杂模型

4.1 函数式API

tf.keras.Sequential 模型是层的简单堆叠,无法表示任意模型。使用 Keras的函数式API可以构建复杂的模型拓扑,例如:

  • 多输入模型,

  • 多输出模型,

  • 具有共享层的模型(同一层被调用多次),

  • 具有非序列数据流的模型(例如,残差连接)。

使用函数式 API 构建的模型具有以下特征:

  • 层实例可调用并返回张量。
  • 输入张量和输出张量用于定义 tf.keras.Model 实例。
  • 此模型的训练方式和 Sequential 模型一样。
In [9]:
input_x = tf.keras.Input(shape=(72,))
hidden1 = layers.Dense(32, activation='relu')(input_x)
hidden2 = layers.Dense(16, activation='relu')(hidden1)
pred = layers.Dense(10, activation='softmax')(hidden2)
# 构建tf.keras.Model实例
model = tf.keras.Model(inputs=input_x, outputs=pred)
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
             loss=tf.keras.losses.categorical_crossentropy,
             metrics=['accuracy'])
model.fit(train_x, train_y, batch_size=32, epochs=5)
 
Train on 1000 samples
Epoch 1/5
1000/1000 [==============================] - 0s 351us/sample - loss: 13.1064 - accuracy: 0.1080
Epoch 2/5
1000/1000 [==============================] - 0s 59us/sample - loss: 21.9265 - accuracy: 0.1180
Epoch 3/5
1000/1000 [==============================] - 0s 68us/sample - loss: 33.9123 - accuracy: 0.1210
Epoch 4/5
1000/1000 [==============================] - 0s 52us/sample - loss: 52.5335 - accuracy: 0.1190
Epoch 5/5
1000/1000 [==============================] - 0s 40us/sample - loss: 85.4629 - accuracy: 0.1180
Out[9]:
<tensorflow.python.keras.callbacks.History at 0x7f0d40696fd0>
 

4.2 模型子类化

可以通过对 tf.keras.Model 进行子类化并定义自己的前向传播来构建完全可自定义的模型。

  • 在__init__ 方法中创建层并将它们设置为类实例的属性。
  • 在__call__方法中定义前向传播
In [10]:
class MyModel(tf.keras.Model):
    def __init__(self, num_classes=10):
        super(MyModel, self).__init__(name='my_model')
        self.num_classes = num_classes
        # 定义网络层
        self.layer1 = layers.Dense(32, activation='relu')
        self.layer2 = layers.Dense(num_classes, activation='softmax')
    def call(self, inputs):
        # 定义前向传播
        h1 = self.layer1(inputs)
        out = self.layer2(h1)
        return out
    
    def compute_output_shape(self, input_shape):
        # 计算输出shape
        shape = tf.TensorShape(input_shape).as_list()
        shape[-1] = self.num_classes
        return tf.TensorShape(shape)
# 实例化模型类,并训练
model = MyModel(num_classes=10)
model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),
             loss=tf.keras.losses.categorical_crossentropy,
             metrics=['accuracy'])

model.fit(train_x, train_y, batch_size=16, epochs=5)
 
Train on 1000 samples
Epoch 1/5
1000/1000 [==============================] - 1s 1ms/sample - loss: 15.5352 - accuracy: 0.1130
Epoch 2/5
1000/1000 [==============================] - 0s 84us/sample - loss: 23.1448 - accuracy: 0.1150
Epoch 3/5
1000/1000 [==============================] - 0s 75us/sample - loss: 31.5394 - accuracy: 0.1000
Epoch 4/5
1000/1000 [==============================] - 0s 70us/sample - loss: 39.0555 - accuracy: 0.1040
Epoch 5/5
1000/1000 [==============================] - 0s 77us/sample - loss: 45.8000 - accuracy: 0.1010
Out[10]:
<tensorflow.python.keras.callbacks.History at 0x7f0d400d8dd8>
 

4.3 自定义层

通过对 tf.keras.layers.Layer 进行子类化并实现以下方法来创建自定义层:

  • __init__: (可选)定义该层要使用的子层
  • build:创建层的权重。使用 add_weight 方法添加权重。

  • call:定义前向传播。

  • compute_output_shape:指定在给定输入形状的情况下如何计算层的输出形状。

  • 可选,可以通过实现 get_config 方法和 from_config 类方法序列化层。
In [11]:
class MyLayer(layers.Layer):
    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)
    
    def build(self, input_shape):
        shape = tf.TensorShape((input_shape[1], self.output_dim))
        self.kernel = self.add_weight(name='kernel1', shape=shape,
                                   initializer='uniform', trainable=True)
        super(MyLayer, self).build(input_shape)
    
    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)

    def compute_output_shape(self, input_shape):
        shape = tf.TensorShape(input_shape).as_list()
        shape[-1] = self.output_dim
        return tf.TensorShape(shape)

    def get_config(self):
        base_config = super(MyLayer, self).get_config()
        base_config['output_dim'] = self.output_dim
        return base_config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

# 使用自定义网络层构建模型
model = tf.keras.Sequential(
[
    MyLayer(10),
    layers.Activation('softmax')
])


model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),
             loss=tf.keras.losses.categorical_crossentropy,
             metrics=['accuracy'])

model.fit(train_x, train_y, batch_size=16, epochs=5)
 
Train on 1000 samples
Epoch 1/5
1000/1000 [==============================] - 0s 258us/sample - loss: 11.5199 - accuracy: 0.0860
Epoch 2/5
1000/1000 [==============================] - 0s 67us/sample - loss: 11.5205 - accuracy: 0.0870
Epoch 3/5
1000/1000 [==============================] - 0s 69us/sample - loss: 11.5205 - accuracy: 0.0850
Epoch 4/5
1000/1000 [==============================] - 0s 72us/sample - loss: 11.5201 - accuracy: 0.0830
Epoch 5/5
1000/1000 [==============================] - 0s 67us/sample - loss: 11.5201 - accuracy: 0.0810
Out[11]:
<tensorflow.python.keras.callbacks.History at 0x7f0d242d8320>
 

4.3 回调

回调是传递给模型以自定义和扩展其在训练期间的行为的对象。我们可以编写自己的自定义回调,或使用tf.keras.callbacks中的内置函数,常用内置回调函数如下:

  • tf.keras.callbacks.ModelCheckpoint:定期保存模型的检查点。
  • tf.keras.callbacks.LearningRateScheduler:动态更改学习率。
  • tf.keras.callbacks.EarlyStopping:验证性能停止提高时进行中断培训。
  • tf.keras.callbacks.TensorBoard:使用TensorBoard监视模型的行为 。
In [12]:
callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
    tf.keras.callbacks.TensorBoard(log_dir='./logs')
]
model.fit(train_x, train_y, batch_size=16, epochs=5,
         callbacks=callbacks, validation_data=(val_x, val_y))
 
Train on 1000 samples, validate on 200 samples
Epoch 1/5
1000/1000 [==============================] - 0s 142us/sample - loss: 11.5200 - accuracy: 0.0870 - val_loss: 11.7025 - val_accuracy: 0.0750
Epoch 2/5
1000/1000 [==============================] - 0s 88us/sample - loss: 11.5206 - accuracy: 0.0900 - val_loss: 11.7015 - val_accuracy: 0.0950
Epoch 3/5
1000/1000 [==============================] - 0s 88us/sample - loss: 11.5197 - accuracy: 0.0810 - val_loss: 11.7022 - val_accuracy: 0.1100
Epoch 4/5
1000/1000 [==============================] - 0s 91us/sample - loss: 11.5198 - accuracy: 0.0810 - val_loss: 11.7029 - val_accuracy: 0.1250
Out[12]:
<tensorflow.python.keras.callbacks.History at 0x7f0d086590f0>
 

5 模型保存与恢复

5.1 权重保存

In [17]:
model = tf.keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(32,)),  # 需要有input_shape
layers.Dense(10, activation='softmax')])

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
In [18]:
# 权重保存与重载
model.save_weights('./weights/model')
model.load_weights('./weights/model')
# 保存为h5格式
model.save_weights('./model.h5', save_format='h5')
model.load_weights('./model.h5')
 

5.2 保存网络结构

In [19]:
# 序列化成json
import json
import pprint
json_str = model.to_json()
pprint.pprint(json.loads(json_str))
# 从json中重建模型
fresh_model = tf.keras.models.model_from_json(json_str)
 
{'backend': 'tensorflow',
 'class_name': 'Sequential',
 'config': {'layers': [{'class_name': 'Dense',
                        'config': {'activation': 'relu',
                                   'activity_regularizer': None,
                                   'batch_input_shape': [None, 32],
                                   'bias_constraint': None,
                                   'bias_initializer': {'class_name': 'Zeros',
                                                        'config': {}},
                                   'bias_regularizer': None,
                                   'dtype': 'float32',
                                   'kernel_constraint': None,
                                   'kernel_initializer': {'class_name': 'GlorotUniform',
                                                          'config': {'seed': None}},
                                   'kernel_regularizer': None,
                                   'name': 'dense_23',
                                   'trainable': True,
                                   'units': 64,
                                   'use_bias': True}},
                       {'class_name': 'Dense',
                        'config': {'activation': 'softmax',
                                   'activity_regularizer': None,
                                   'bias_constraint': None,
                                   'bias_initializer': {'class_name': 'Zeros',
                                                        'config': {}},
                                   'bias_regularizer': None,
                                   'dtype': 'float32',
                                   'kernel_constraint': None,
                                   'kernel_initializer': {'class_name': 'GlorotUniform',
                                                          'config': {'seed': None}},
                                   'kernel_regularizer': None,
                                   'name': 'dense_24',
                                   'trainable': True,
                                   'units': 10,
                                   'use_bias': True}}],
            'name': 'sequential_6'},
 'keras_version': '2.2.4-tf'}
In [20]:
# 保持为yaml格式  #需要提前安装pyyaml

yaml_str = model.to_yaml()
print(yaml_str)
# 从yaml数据中重新构建模型
fresh_model = tf.keras.models.model_from_yaml(yaml_str)
 
backend: tensorflow
class_name: Sequential
config:
  layers:
  - class_name: Dense
    config:
      activation: relu
      activity_regularizer: null
      batch_input_shape: !!python/tuple [null, 32]
      bias_constraint: null
      bias_initializer:
        class_name: Zeros
        config: {}
      bias_regularizer: null
      dtype: float32
      kernel_constraint: null
      kernel_initializer:
        class_name: GlorotUniform
        config: {seed: null}
      kernel_regularizer: null
      name: dense_23
      trainable: true
      units: 64
      use_bias: true
  - class_name: Dense
    config:
      activation: softmax
      activity_regularizer: null
      bias_constraint: null
      bias_initializer:
        class_name: Zeros
        config: {}
      bias_regularizer: null
      dtype: float32
      kernel_constraint: null
      kernel_initializer:
        class_name: GlorotUniform
        config: {seed: null}
      kernel_regularizer: null
      name: dense_24
      trainable: true
      units: 10
      use_bias: true
  name: sequential_6
keras_version: 2.2.4-tf

 

注意:子类模型不可序列化,因为其体系结构由call方法主体中的Python代码定义。

 

5.3 保存整个模型

In [21]:
model = tf.keras.Sequential([
  layers.Dense(10, activation='softmax', input_shape=(72,)),
  layers.Dense(10, activation='softmax')
])
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(train_x, train_y, batch_size=32, epochs=5)
# 保存整个模型
model.save('all_model.h5')
# 导入整个模型
model = tf.keras.models.load_model('all_model.h5')
 
Train on 1000 samples
Epoch 1/5
1000/1000 [==============================] - 0s 380us/sample - loss: 11.5446 - accuracy: 0.1030
Epoch 2/5
1000/1000 [==============================] - 0s 43us/sample - loss: 11.5470 - accuracy: 0.1090
Epoch 3/5
1000/1000 [==============================] - 0s 46us/sample - loss: 11.5488 - accuracy: 0.1090
Epoch 4/5
1000/1000 [==============================] - 0s 44us/sample - loss: 11.5616 - accuracy: 0.1090
Epoch 5/5
1000/1000 [==============================] - 0s 43us/sample - loss: 11.5849 - accuracy: 0.1090
 

6 将keras用于Estimator

Estimator API 用于针对分布式环境训练模型。它适用于一些行业使用场景,例如用大型数据集进行分布式训练并导出模型以用于生产

In [22]:
model = tf.keras.Sequential([layers.Dense(10,activation='softmax'),
                          layers.Dense(10,activation='softmax')])

model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

estimator = tf.keras.estimator.model_to_estimator(model)
 
WARNING: Logging before flag parsing goes to stderr.
W1003 15:52:41.191763 139697242609472 estimator.py:1821] Using temporary folder as model directory: /tmp/tmpliq4yvi6
 

7 Eager execution

 

Eager execution是一个动态执行的编程环境,它可以立即评估操作。Keras不需要此功能,但它受tf.keras程序支持和对检查程序和调试有用。

所有的tf.keras模型构建API都与Eager execution兼容。尽管可以使用Sequential和函数API,但Eager execution有利于模型子类化和构建自定义层:其要求以代码形式编写前向传递的API(而不是通过组装现有层来创建模型的API)。

 

8 多GPU上运行

 

tf.keras模型可使用tf.distribute.Strategy在多个GPU上运行 。该API在多个GPU上提供了分布式培训,几乎无需更改现有代码。

当前tf.distribute.MirroredStrategy是唯一受支持的分发策略。MirroredStrategy在单台计算机上使用全缩减进行同步训练来进行图内复制。要使用 distribute.Strategys,请将优化器实例化以及模型构建和编译嵌套在Strategys中.scope(),然后训练模型。

以下示例tf.keras.Model在单个计算机上的多GPU分配。

首先,在分布式策略范围内定义一个模型:

In [23]:
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential()
    model.add(layers.Dense(16, activation='relu', input_shape=(10,)))
    model.add(layers.Dense(1, activation='sigmoid'))
    optimizer = tf.keras.optimizers.SGD(0.2)
    model.compile(loss='binary_crossentropy', optimizer=optimizer)
model.summary()
 
W1003 15:52:43.881619 139697242609472 cross_device_ops.py:1209] There is non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
 
Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_29 (Dense)             (None, 16)                176       
_________________________________________________________________
dense_30 (Dense)             (None, 1)                 17        
=================================================================
Total params: 193
Trainable params: 193
Non-trainable params: 0
_________________________________________________________________
 

然后像单gpu一样在数据上训练模型即可

In [24]:
x = np.random.random((1024, 10))
y = np.random.randint(2, size=(1024, 1))
x = tf.cast(x, tf.float32)
dataset = tf.data.Dataset.from_tensor_slices((x, y))
dataset = dataset.shuffle(buffer_size=1024).batch(32)
model.fit(dataset, epochs=1)
 
32/32 [==============================] - 1s 42ms/step - loss: 0.7060
Out[24]:
<tensorflow.python.keras.callbacks.History at 0x7f0cde166860>
In [ ]:
 
In [ ]:
 
 posted on 2020-10-19 16:53  大码王  阅读(302)  评论(0编辑  收藏  举报
复制代码