使用@tf.function加快训练速度
TensorFlow 2 默认的即时执行模式(Eager Execution)为我们带来了灵活及易调试的特性,但为了追求更快的速度与更高的性能,我们依然希望使用 TensorFlow 1.X 中默认的图执行模式(Graph Execution)。此时,TensorFlow 2 为我们提供了 tf.function
模块,结合 AutoGraph 机制,使得我们仅需加入一个简单的 @tf.function
修饰符,就能轻松将模型以图执行模式运行。
实现方式
只需要将我们希望以图执行模式运行的代码封装在一个函数内,并在函数前加上 @tf.function
即可。
import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import time
np.random.seed(42) # 设置numpy随机数种子
tf.random.set_seed(42) # 设置tensorflow随机数种子
# 生成训练数据
x = np.linspace(-1, 1, 100)
x = x.astype('float32')
y = x * x + 1 + np.random.rand(100)*0.1 # y=x^2+1 + 随机噪声
x_train = np.expand_dims(x, 1) # 将一维数据扩展为二维
y_train = np.expand_dims(y, 1) # 将一维数据扩展为二维
plt.plot(x, y, '.') # 画出训练数据
def create_model():
inputs = keras.Input((1,))
x = keras.layers.Dense(10, activation='relu')(inputs)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
model = create_model() # 创建一个模型
loss_fn = keras.losses.MeanSquaredError() # 定义损失函数
optimizer = keras.optimizers.SGD() # 定义优化器
@tf.function # 将训练过程转化为图执行模式
def train():
with tf.GradientTape() as tape:
y_pred = model(x_train, training=True) # 前向传播,注意不要忘了training=True
loss = loss_fn(y_train, y_pred) # 计算损失
tf.summary.scalar("loss", loss, epoch+1) # 将损失写入tensorboard
grads = tape.gradient(loss, model.trainable_variables) # 计算梯度
optimizer.apply_gradients(zip(grads, model.trainable_variables)) # 使用优化器进行反向传播
return loss
epochs = 1000
begin_time = time.time() # 训练开始时间
for epoch in range(epochs):
loss = train()
print('epoch:', epoch+1, '\t', 'loss:', loss.numpy()) # 打印训练信息
end_time = time.time() # 训练结束时间
print("训练时长:", end_time-begin_time)
# 预测
y_pre = model.predict(x_train)
# 画出预测值
plt.plot(x, y_pre.squeeze())
plt.show()
通过实验得出结论:如果不使用@tf.function
,那么训练时间大约为3秒。如果使用@tf.function
,训练时间仅需要0.5秒。快了很多倍。
内在原理
使用@tf.function
的函数在执行时会生成一个计算图,里面的操作就是计算图的每个节点。下次调用相同的函数,且参数类型相同时,则会直接使用这个计算图计算。若函数名不同或参数类型不同时,则会另外生成一个新的计算图。
注意点
建议在函数内只使用 TensorFlow 的原生操作,不要使用过于复杂的 Python 语句,函数参数最好只包括 TensorFlow 张量或 NumPy 数组。
-
因为只有tf的原生操作才会在计算图中生产节点。(如python的原生
print()
函数不会生成节点,而tensorflow的tf.print()
会) -
对于Tensorflow张量或Numpy数组作为参数的函数,只要类型相同便可重用之前的计算图。而对于python原声数据(如原生的整数、浮点数 1,1.5等)必须参数的值一模一样才会重用之前的计算图,否则的话会创建新的计算图。
另外,一般而言,当模型由较多小的操作组成的时候, @tf.function
带来的提升效果较大。而当模型的操作数量较少,但单一操作均很耗时的时候,则 @tf.function
带来的性能提升不会太大。