第8章 Keras高层接口

Keras 是一个主要由 Python 语言开发的开源神经网络计算库。只能使用 Keras 的接口来完成 TensorFlow 层方式的模型搭建与训练。在 TensorFlow 中, Keras 被实现在 tf.keras 子模块中。

8.1 常见功能模块

Keras 提供了一系列高层的神经网络类和函数,如常见数据集加载函数,网络层类模型容器损失函数类优化器类经典模型类等等 。

8.1.1 常见网络层类

对于常见的神经网络层,可以使用张量方式的底层接口函数来实现,这些接口函数一般在tf.nn模块中。

对于常见的网络层,一般直接使用层方式来完成模型的搭建:

tf.keras.layers 命名空间中提供了大量常见网络层的类接口,如全连接层, 激活含水层, 池化层,卷积层,循环神经网络层等等。对于这些网络层类,只需要在创建时指定网络层的相关参数, 并调用__call__方法即可完成前向计算。(注:在调用__call__方法时, Keras 会自动调用每个层的前向传播逻辑,这些逻辑一般实现
在类的 call 函数中 )

例(Softmax):

import tensorflow as tf
# 导入 keras 模型,不能使用 import keras,它导入的是标准的 Keras 库
from tensorflow import keras
from tensorflow.keras import layers
x = tf.constant([2.,1.,0.1])
layer = layer.Softmax(axis=-1)
layer(x)

8.1.2 网络容器

当网络层数变得较深时 , 可以通过Keras提供的网络容器Sequential将多个网络层封装成一个大网络模型,只需要调用网络模型的实例一次即可完成数据从第一层到最末层的顺序运算。

from tensorflow.keras import layers, Sequential
network = Sequential([
    layers.Dense(3, activation=None),
    layers.ReLU(),
    layers.Dense(2, activation=None),
    layers.ReLU()
])
x = tf.random.normal([4, 3])
network(x)

Sequential通过add()方法继续追加新的网络层 ,实现动态创建网络的功能:

layers_num = 2
network = Sequential([])
for _ in range(layers_num):
    network.add(layers.Dense(3))
    network.add(layers.ReLU())
network.build(input_shape=(None, 4))  # 创建网络参数
network.summary()

Sequential 对象的trainable_variables variables包含了所有层的待优化张量列表和全部张量列表。

# 打印网络的待优化参数名与 shape
for p in network.trainable_variables:
    print(p.name, p.shape)

8.2 模型装配、训练与测试

在训练网络时,一般的流程是通过前向计算获得网络的输出值, 再通过损失函数计算网络误差,然后通过自动求导工具计算梯度并更新,同时间隔性地测试网络的性能。

8.2.1模型装配

keras.Modelkeras.layers.Layer

  • Layer类是网络层的母类,定义了网络层的一些常见功能,如添加权值,管理权值列表等。
  • Model 类是网络的母类,除了具有 Layer 类的功能,还添加了保存、加载模型,训练与测试模型等便捷功能 。Sequential 也是 Model 的子类。

创建5层全连接网络用语MNIST手写数字图片识别:

network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dense(128, activation='relu'),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()

创建网络后的正常流程是,通过循环迭代数据集多遍,每次按批产生训练数据,前向计算,然后通过损失函数计算误差值,并反向传播自动计算梯度,更新网络参数。

在keras中提供了compile()fit()函数方便实现创建网络后的正常流程。

  • compile 函数指定网络使用的优化器对象,损失函数, 评价指标等:
# 采用 Adam 优化器,学习率为 0.01;采用交叉熵损失函数,包含 Softmax
network.compile(optimizer=optimizers.Adam(lr=0.01),
                loss=losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

8.2.2 模型训练

模型装配完成后,即可通过fit()函数送入待训练的数据和验证用的数据集:

# 指定训练集为 train_db,验证集为 val_db,训练 5 个 epochs,每 2 个 epoch 验证一次
# 返回训练信息保存在 history 中
history = network.fit(train_db, epochs=5, validation_data=val_db,
validation_freq=2)

fit函数会返回训练过程的数据记录

8.2.3模型测试

Model基类除了可以便捷地完成网络的装配与训练、验证,还可以非常方便的预测和测试。

通过Model.predict(x)方法即可完成模型的预测:

x, y = next(iter(db_test))
print('predict x:', x.shape)
out = network.predict(x)
print(out)

简单测试模型性能:

network.evaluate(db_test)

8.3 模型保存与加载

在Keras中,有三种常用的模型保存与加载方法

8.3.1张量方式

网络的状态主要体现在网络的结构以及网络层内部张量参数上,因此在拥有网络结构源文件的条件下,直接保存网络张量参数到文件上是最轻量级的一种方式。

通过调用Model.save_weights(path)方法即可将当前的网络参数保存到path文件上:

network.save_weights('weights.ckpt')

只需要先创建好网络对象, 然后调用网络对象的 load_weights(path)方法即可将指定的模型文件中保存的张量数值写入到当前网络参数中去:

# 保存模型参数到文件上
network.save_weights('weights.ckpt')
print('saved weights')
del network  # 删除网络对象
# 重新创建相同的网络结构
network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dense(128, activation='relu'),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(lr=0.01),
                loss=losses.CategoricalCrossentropy(from_logits=True),
# 从参数文件中读取数据并写入当前网络
network.load_weights('weights.ckpt')
print('loaded weights!')

这种保存与加载网络的方式最为轻量级, 文件中保存的仅仅是参数张量的数值,并没有其他额外的结构参数。但是它需要使用相同的网络结构才能够恢复网络状态,因此一般在拥有网络源文件的情况下使用。

8.3.2网络方式

不需要网络源文件,仅仅需要模型参数即可恢复网络模型的方式。

通过 Model.save(path)函数可以将模型的结构以及模型的参数保存到一个 path 文件上,在不需要网络源文件的条件下,通过 keras.models.load_model(path)即可恢复网络结构和网络参数。

# 保存模型结构与模型参数到文件
network.save('model.h5')
print('saved total model.')
del network # 删除网络对象

此时通过 model.h5 文件即可恢复出网络的结构和状态:

network = tf.keras.models.load_model('model.h5')

文件除了保存了模型参数外,还保存了网络结构信息,不需要提前创
建模型即可直接从文件中恢复出网络 network 对象 。

8.3.3SavedModel方式

通过tf.keras.experimental.export_saved_model(network,path)即可将模型以SavedModel方式保存到path目录中:

# 保存模型结构与模型参数到文件
tf.keras.experimental.export_save_model(network, 'model-savedmodel')
print('export saved model.')
del network # 删除网络对象
# 从文件恢复网络结构与网络参数
network = tf.keras.experimental.load_from_saved_model('model-savedmodel')

即可恢复出网络结构和参数,方便各个平台能够无缝对接训练好的网络模型。

8.4 自定义类

在创建自定义网络层类时, 需要继承自layers.Layer基类,创建自定义的网络类,需要继承自keras.Model 基类 。

8.4.1 自定义网络层

对于自定义的网络层,需要实现初始化__init__方法和前向传播逻辑 call 方法 .

自定义网络层:

class MyDense(layers.Layer):
    # 自定义网络层
    def __init__(self, inp_dim, outp_dim):
        super(MyDense, self).__init__()
        # 创建权值张量并添加到类管理列表中,设置为需要优化
        self.kernel = self.add_variable('w', [inp_dim, outp_dim], trainable=True)
net = MyDense(4, 3)

前向计算逻辑:

def call(self, inputs, training=None):
    # 实现自定义类的前向计算逻辑
    # X@W
    out = inputs @ self.kernel
    # 执行激活函数运算
    out = tf.nn.relu(out)
    return out

自定义类的前向运算逻辑需要实现在call(inputs,training)函数中

  • inputs 输入,调用传入
  • trainning参数用于指定模型的状态
    • training 为 True ——训练模式
    • training 为False——测试模式
    • 默认参数为None——测试模式

全连接层的训练模式和测试模式逻辑一致,不需要额外处理。对于部份测试模式和训练模式不一致的网络层,需要根据 training 参数来设计需要执行的逻辑。

8.4.2自定义网络

自定义的类可以和其他标准类一样,通过 Sequential 容器方便地包裹成一个网络模型 :

network = Sequential([MyDense(784, 256),
                      MyDense(256, 128),
                      MyDense(128, 64),
                      MyDense(64, 32),
                      MyDense(32, 10)])
network.build(input_shape=(None, 28*28))
network.summary()

更普遍地 ,可以继承基类来实现任意逻辑的自定义网络类创建自定
义网络类,首先创建并继承 Model 基类,分布创建对应的网络层对象:

class MyModel(keras.Model):
    # 自定义网络类,继承自Model基类
    def __init(self):
        super(MyModel, self).__init__()
        #  完成网络内需要的网络层的创建工作
        self.fc1 = MyDense(28*28, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        x = self.fc1(inputs)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        x = self.fc5(x)
        return x

这种方法更加灵活。

8.5模型乐园

对于常用的网络模型,如 ResNetVGG等,不需要手动创建网络,可以直接从keras.applications 子模块下一行代码即可创建并使用这些经典模型,同时还可以通过设置weights 参数加载预训练的网络参数,非常方便。

8.5.1加载模型

以 ResNet50 迁移学习为例,一般将 ResNet50 去掉最后一层后的网络作为新任务的特征提取子网络,即利用 ImageNet 上面预训练的特征提取方法迁移到我们自定义的数据集上,并根据自定义任务的类别追加一个对应数据类别数的全连接分类层, 从而可以在预训练网络的基础上可以快速、高效地学习新任务。

  1. 首先利用 Keras 模型乐园加载 ImageNet 预训练的 ResNet50 网络
# 加载ImageNet 预训练网络模型,并去掉最后一层
resnet = keras.applications.ResNet50(weights='imagenet', include_top=False)
resnet.summary()
# 测试网络的输出
x = tf.random.normal([4, 224, 224, 3])
out = resnet(x)
out.shape
  1. 需要设置自定义的输出节点数,以 100 类的分类任务为例新建一个
    池化层(这里的池化层可以理解为维度缩减功能),将特征[b,7,7,2048]降维到[b,2048]

    # 新建池化层
    global_average_layer = layers.GlobalAveragePooling2D()
    # 利用上一层的输出作为本层的输入,测试其输出
    x = tf.random.normal([4, 7, 7, 2048])
    out = global_average_layer(x)  # 池化层降维
    print(out.shape)
    
  2. 最后新建一个全连接层,并设置输出节点数为 100

# 新建全连接层
fc = layers.Dense(100)
# 利用上一层的输出作为本层的输入,测试其输出
x = tf.random.normal([4, 2048])
out = fc(x)
print(out.shape)
  1. 在得到预训练的 ResNet50 特征层和我们新建的池化层、全连接层后,我们重新利用Sequential 容器封装成一个新的网络:

    # 重新包裹成我们的网络模型
    mynet = Sequential([resnet, global_average_layer, fc])
    mynet.summary()
    
  2. 通过设置 resnet.trainable = False 可以选择冻结 ResNet 部分的网络参数,只训练新建的网络层,从而快速、高效完成网络模型的训练。

8.6 测量工具

在网络的训练过程中,经常需要统计准确率,召回率等信息,Keras 提供了一些常用的测量工具 keras.metrics,专门用于统计训练过程中需要的指标数据。

Keras 的测量工具的使用一般有 4 个基本操作流程: 新建测量器写入数据读取统计数据清零测量器

8.6.1新建测量器

keras.metrics模块提供了较多的常用测量类:

  • 统计平均值的Mean
  • 统计准确率的 Accuracy
  • 统计余弦相似度的CosineSimilarity

在前向运算时,我们会得到每一个 batch 的平均误差,但是我们希望统计一个epoch 的平均误差,因此我们选择使用 Mean 测量器:

# 新建平均测量器,适合Loss数据
loss_meter = metrics.Mean()

8.6.2 写入数据

通过测量器的update_state函数可以写入新的数据:

# 记录采样的数据
loss_meter.update_state(float(loss))

上述采样代码放置在每个 batch 运算完成后, 测量器会自动根据采样的数据来统计平均值。

8.6.3 读取统计信息

在采样多次后,可以测量器的result()函数获取统计值:

# 打印统计的平均loss
print(step, 'loss:', loss_meter.result())

8.6.4 清除

测量器会统计所有历史记录的数据,在合适的时候有必要清除历史状态。通过 reset_states()即可实现 。

例如,在每次读取完平均误差后, 清零统计信息,以便下一轮统计的开始:

if step % 100 ==0:
    print(step, 'loss:', loss_meter.result())
    loss_meter.reset_states()  # 清零测量器

8.7 可视化

TensorFlow 提供了一个专门的可视化工具,叫做TensorBoard。TensorBoard 的使用需要训练部分和浏览器交互工作。

posted @ 2021-11-24 15:49  Reversal-destiny  阅读(72)  评论(0编辑  收藏  举报