6-3 Training a Model on a Single GPU#
Training a deep learning model is often very time-consuming. Training a model for several hours is routine, training for several days is common, and occasionally it can take dozens of days.
The training time is mainly spent on two things: data preparation and parameter iteration.
When data preparation is the main bottleneck of training time, we can use more processes to prepare the data (see the sketch after this paragraph).
When parameter iteration becomes the main bottleneck, the usual approach is to accelerate it with a GPU or with Google's TPU.
For details, see 《用GPU加速Keras模型——Colab免费GPU使用攻略》 (Accelerating Keras models with a GPU: a guide to Colab's free GPU): https://zhuanlan.zhihu.com/p/68509398
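For the data-preparation point, a minimal sketch of parallelizing preparation with tf.data is shown below (not from the original text; parse_example is a hypothetical placeholder for an expensive per-sample preprocessing step):

import tensorflow as tf

def parse_example(x):
    # Hypothetical stand-in for a costly per-sample preprocessing step
    return tf.cast(x, tf.float32) / 255.0

ds = (tf.data.Dataset.range(10000)
      .map(parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)  # prepare samples in parallel threads
      .batch(32)
      .prefetch(tf.data.experimental.AUTOTUNE))  # overlap data preparation with training on the accelerator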
Whether you use the built-in fit method or a custom training loop, switching from CPU to single-GPU training is very convenient and requires no code changes. When a usable GPU is available and no device is specified explicitly, TensorFlow automatically prefers the GPU for creating tensors and executing tensor computations.
On a shared server in a company or a university lab, however, where there are several GPUs and several users, a single job should not grab all GPU resources and lock everyone else out (by default TensorFlow claims all the memory of every GPU, even though it actually uses only part of one GPU). We therefore usually add a few lines like the following at the top of the script to control which GPU a job uses and how much GPU memory it may take, so that others can train their models at the same time.
import tensorflow as tf
print(tf.__version__)
"""
2.6.0
"""
from tensorflow.keras import *
# Print a timestamp divider line
@tf.function
def printbar():
    ts = tf.timestamp()
    today_ts = ts % (24*60*60)

    # +8 converts UTC to Beijing time (UTC+8)
    hour = tf.cast(today_ts//3600 + 8, tf.int32) % tf.constant(24)
    minute = tf.cast((today_ts % 3600)//60, tf.int32)
    second = tf.cast(tf.floor(today_ts % 60), tf.int32)

    def timeformat(m):
        if tf.strings.length(tf.strings.format("{}", m)) == 1:
            return tf.strings.format("0{}", m)
        else:
            return tf.strings.format("{}", m)

    timestring = tf.strings.join([timeformat(hour), timeformat(minute),
                                  timeformat(second)], separator=":")
    tf.print("=========="*8, end="")
    tf.print(timestring)
GPU setup#
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]  # If there are multiple GPUs, use only GPU 0
    tf.config.experimental.set_memory_growth(gpu0, True)  # Let GPU memory grow on demand
    # Alternatively, cap GPU memory at a fixed amount (e.g. 4 GB):
    # tf.config.experimental.set_virtual_device_configuration(
    #     gpu0, [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)]
    # )
    tf.config.set_visible_devices([gpu0], "GPU")
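As a side note (a sketch, not part of the original code): restricting which GPU is visible can also be done through the CUDA_VISIBLE_DEVICES environment variable, provided it is set before TensorFlow initializes the GPU:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to this process; must run before TensorFlow touches the GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  # memory growth still has to be requested explicitly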
- Compare the computation speed of the GPU and the CPU
printbar()
with tf.device("/gpu:0"):
    tf.random.set_seed(0)
    a = tf.random.uniform((10000, 1000), minval=0, maxval=3.0)
    b = tf.random.uniform((1000, 100000), minval=0, maxval=3.0)
    c = a@b
    tf.print(tf.reduce_sum(tf.reduce_sum(c, axis=0), axis=0))
printbar()
"""
================================================================================16:38:10
2.24975205e+12
================================================================================16:38:10
"""
printbar()
with tf.device("/cpu:0"):
    tf.random.set_seed(0)
    a = tf.random.uniform((10000, 1000), minval=0, maxval=3.0)
    b = tf.random.uniform((1000, 100000), minval=0, maxval=3.0)
    c = a@b
    tf.print(tf.reduce_sum(tf.reduce_sum(c, axis=0), axis=0))
printbar()
"""
================================================================================16:38:19
2.24977984e+12
================================================================================16:38:22
"""
Data preparation#
MAX_LEN = 300
BATCH_SIZE = 32
(x_train, y_train), (x_test, y_test) = datasets.reuters.load_data()
x_train = preprocessing.sequence.pad_sequences(x_train, maxlen=MAX_LEN)
x_test = preprocessing.sequence.pad_sequences(x_test, maxlen=MAX_LEN)
MAX_WORDS = x_train.max() + 1
CAT_NUM = y_train.max() + 1
ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
.shuffle(buffer_size=1000).batch(BATCH_SIZE) \
.prefetch(tf.data.experimental.AUTOTUNE).cache()
ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) \
.shuffle(buffer_size=1000).batch(BATCH_SIZE) \
.prefetch(tf.data.experimental.AUTOTUNE).cache()
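A quick sanity check on the pipeline (not in the original text): pull a single batch and confirm that the padded shape matches BATCH_SIZE and MAX_LEN.

for features, labels in ds_train.take(1):
    tf.print(tf.shape(features))  # expected: [32 300], i.e. (BATCH_SIZE, MAX_LEN)
    tf.print(tf.shape(labels))    # expected: [32]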
Model definition#
tf.keras.backend.clear_session()
def create_model():
    model = models.Sequential()
    model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
    model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM, activation="softmax"))
    return model
model = create_model()
model.summary()
"""
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) (None, 300, 7) 216874
_________________________________________________________________
conv1d (Conv1D) (None, 296, 64) 2304
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 148, 64) 0
_________________________________________________________________
conv1d_1 (Conv1D) (None, 146, 32) 6176
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 73, 32) 0
_________________________________________________________________
flatten (Flatten) (None, 2336) 0
_________________________________________________________________
dense (Dense) (None, 46) 107502
=================================================================
Total params: 332,856
Trainable params: 332,856
Non-trainable params: 0
_________________________________________________________________
"""
Model training#
optimizer = optimizers.Nadam()
loss_func = losses.SparseCategoricalCrossentropy()
train_loss = metrics.Mean(name="train_loss")
train_metric = metrics.SparseCategoricalAccuracy(name="train_accuracy")
valid_loss = metrics.Mean(name="valid_loss")
valid_metric = metrics.SparseCategoricalAccuracy(name="valid_accuracy")
@tf.function
def train_step(model, features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = loss_func(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_metric.update_state(labels, predictions)
@tf.function
def valid_step(model, features, labels):
    predictions = model(features)
    batch_loss = loss_func(labels, predictions)
    valid_loss.update_state(batch_loss)
    valid_metric.update_state(labels, predictions)
def train_model(model, ds_train, ds_valid, epochs):
    for epoch in tf.range(1, epochs+1):

        for features, labels in ds_train:
            train_step(model, features, labels)

        for features, labels in ds_valid:
            valid_step(model, features, labels)

        logs = "Epoch={},Loss={},Accuracy={},Valid Loss={},Valid Accuracy={}"

        if epoch % 1 == 0:
            printbar()
            tf.print(tf.strings.format(logs, (epoch, train_loss.result(), train_metric.result(),
                                              valid_loss.result(), valid_metric.result())))
            tf.print()

        train_loss.reset_states()
        valid_loss.reset_states()
        train_metric.reset_states()
        valid_metric.reset_states()
train_model(model, ds_train, ds_test, 10)
"""
================================================================================17:05:28
Epoch=1,Loss=0.211605832,Accuracy=0.94388777,Valid Loss=3.39473462,Valid Accuracy=0.61843276
================================================================================17:05:29
Epoch=2,Loss=0.192161322,Accuracy=0.945891798,Valid Loss=3.67727017,Valid Accuracy=0.613535166
================================================================================17:05:31
Epoch=3,Loss=0.178221658,Accuracy=0.945446432,Valid Loss=3.8808856,Valid Accuracy=0.610863745
================================================================================17:05:31
Epoch=4,Loss=0.162928045,Accuracy=0.94611448,Valid Loss=4.25203,Valid Accuracy=0.605966151
================================================================================17:05:33
Epoch=5,Loss=0.152047396,Accuracy=0.948341131,Valid Loss=4.56476116,Valid Accuracy=0.604185224
================================================================================17:05:34
Epoch=6,Loss=0.14381744,Accuracy=0.949565828,Valid Loss=4.83808756,Valid Accuracy=0.60195905
================================================================================17:05:35
Epoch=7,Loss=0.137525707,Accuracy=0.950790465,Valid Loss=4.87792158,Valid Accuracy=0.607747078
================================================================================17:05:37
Epoch=8,Loss=0.130401462,Accuracy=0.952460468,Valid Loss=5.04861832,Valid Accuracy=0.607747078
================================================================================17:05:39
Epoch=9,Loss=0.123182565,Accuracy=0.953573823,Valid Loss=5.1263938,Valid Accuracy=0.605075717
================================================================================17:05:40
Epoch=10,Loss=0.12008135,Accuracy=0.955021143,Valid Loss=5.22552586,Valid Accuracy=0.605966151
"""
Author: lotuslaw
Source: https://www.cnblogs.com/lotuslaw/p/16437958.html
License: this work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International license.