使用@tf.function加快训练速度

TensorFlow 2 默认的即时执行模式(Eager Execution)为我们带来了灵活及易调试的特性,但为了追求更快的速度与更高的性能,我们依然希望使用 TensorFlow 1.X 中默认的图执行模式(Graph Execution)。此时,TensorFlow 2 为我们提供了 tf.function 模块,结合 AutoGraph 机制,使得我们仅需加入一个简单的 @tf.function 修饰符,就能轻松将模型以图执行模式运行。

实现方式

只需要将我们希望以图执行模式运行的代码封装在一个函数内,并在函数前加上 @tf.function 即可。

import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import time

np.random.seed(42)  # 设置numpy随机数种子
tf.random.set_seed(42)  # 设置tensorflow随机数种子

# 生成训练数据
x = np.linspace(-1, 1, 100)
x = x.astype('float32')
y = x * x + 1 + np.random.rand(100)*0.1  # y=x^2+1 + 随机噪声
x_train = np.expand_dims(x, 1)  # 将一维数据扩展为二维
y_train = np.expand_dims(y, 1)  # 将一维数据扩展为二维
plt.plot(x, y, '.')  # 画出训练数据


def create_model():
    inputs = keras.Input((1,))
    x = keras.layers.Dense(10, activation='relu')(inputs)
    outputs = keras.layers.Dense(1)(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = create_model()  # 创建一个模型
loss_fn = keras.losses.MeanSquaredError()  # 定义损失函数
optimizer = keras.optimizers.SGD()  # 定义优化器


@tf.function  # 将训练过程转化为图执行模式
def train():
    with tf.GradientTape() as tape:
        y_pred = model(x_train, training=True)  # 前向传播,注意不要忘了training=True
        loss = loss_fn(y_train, y_pred)  # 计算损失
        tf.summary.scalar("loss", loss, epoch+1)  # 将损失写入tensorboard
    grads = tape.gradient(loss, model.trainable_variables)  # 计算梯度
    optimizer.apply_gradients(zip(grads, model.trainable_variables))  # 使用优化器进行反向传播
    return loss


epochs = 1000
begin_time = time.time()  # 训练开始时间
for epoch in range(epochs):
    loss = train()

    print('epoch:', epoch+1, '\t', 'loss:', loss.numpy())  # 打印训练信息
end_time = time.time()  # 训练结束时间

print("训练时长:", end_time-begin_time)

# 预测
y_pre = model.predict(x_train)

# 画出预测值
plt.plot(x, y_pre.squeeze())
plt.show()

通过实验得出结论:如果不使用@tf.function,那么训练时间大约为3秒。如果使用@tf.function,训练时间仅需要0.5秒。快了很多倍。

内在原理

使用@tf.function的函数在执行时会生成一个计算图,里面的操作就是计算图的每个节点。下次调用相同的函数,且参数类型相同时,则会直接使用这个计算图计算。若函数名不同或参数类型不同时,则会另外生成一个新的计算图。

注意点

建议在函数内只使用 TensorFlow 的原生操作,不要使用过于复杂的 Python 语句,函数参数最好只包括 TensorFlow 张量或 NumPy 数组。

  • 因为只有tf的原生操作才会在计算图中生产节点。(如python的原生print()函数不会生成节点,而tensorflow的tf.print()会)

  • 对于Tensorflow张量或Numpy数组作为参数的函数,只要类型相同便可重用之前的计算图。而对于python原声数据(如原生的整数、浮点数 1,1.5等)必须参数的值一模一样才会重用之前的计算图,否则的话会创建新的计算图。

另外,一般而言,当模型由较多小的操作组成的时候, @tf.function 带来的提升效果较大。而当模型的操作数量较少,但单一操作均很耗时的时候,则 @tf.function 带来的性能提升不会太大。

参考

https://tf.wiki/zh_hans/basic/tools.html#tf-function

posted @ 2021-01-10 14:09  火锅先生  阅读(3096)  评论(0编辑  收藏  举报