Pytorch——Pytorch Lightning框架使用手册
本文主要是记录下,使用PytorchLightning这个如何进行深度学习的训练,记录一下本人平常使用这个框架所需要注意的地方,由于框架的理解深入本文会时不时进行更新(第三部分的常见问题会是不是的更新走的),本文深度参考以下两个网站pytorch_lightning 全程笔记 、Pytorch Lightning 完全攻略如果大家觉得本文写得不是很清楚,大家可以进一步看看这两篇文章。
一、框架使用方案
正如网络上大家介绍的那样,PL框架可以让人专心在模型内部的研究。我们在复杂的项目中,可能会出现多个模型,并且模型多个模型之间存在着许多的联系,如果在项目中想要更换某些模型model,会导致重写很多代码。但是如果采用PL框架,那么这将会是一件比较容易的事情。根据Pytorch Lightning 完全攻略这篇文章的推荐,我建议采用以下的代码风格:
root-
|-dataModule
|-__init__.py
|-data_interface.py
|-xxxdataset1.py
|-xxxdataset2.py
|-...
|-modelModule
|-__init__.py
|-model_interface.py
|-xxxmodel1.py
|-xxxmodel2.py
|-...
|-train.py
其中把dataModule和modelModule写成python包,这两个包的__init__.py分别是:
- from .data_interface import DInterface
- from .model_interface import MInterface
在DInterface和MInterface分别是data_interface.py和model_interface.py中创建的类,他们两个分别就是
- class DInterface(pl.LightningDataModule): 用于所有数据集的接口,在setup()方法中初始化你准备好的xxxdataset1.py,xxxdataset2.py中定义的torch.utils.data.Dataset类。在train_dataloader,val_dataloader,test_dataloader这几个方法中载入Dataloader即可。
- class MInterface(pl.LightningModule): 用作模型的接口,在
__init__()
函数中import你准备好的xxxmodel2.py,xxxmodel1.py这些模型。重写training_step方法,validation_step方法,configure_optimizers方法。
当大家在更改模型的时候只需要在对应的模块上进行更改即可,最后train.py主要功能就是读取参数,和调用dataModule和modelModule这两个包进行实例化DInterface和MInterface,当然一些PL框架的回调函数也需要在train.py里进行定义。
二、框架基本模块(Module)
2.1 LightningModule
LightningModule必须包含的部分是init()和training_step(self, batch, batch_idx),其中init()主要是进行模型的初始化和定义(不需要定义数据集等)。training_step(...)主要是进行定义每个batch数据的处理函数。
def training_step(self, batch, batch_idx):
x, y, z = batch
out = self.encoder(x)
loss = self.loss(out, x)
return loss
这里的batch就是从dataloader中出来的一个batch的数据,类似于for batch in dataloader :。可以看到这个函数的返回值是一个loss,文档中有提到这个方法一定要返回一个loss,如果是返回一个字典,那么必须包含“loss”这个键。当然也有例外,当我们重写training_step_end()方法的时候就不用training_step必须返回一个loss了,此时可以返回任意的东西,但是要注意training_step_end()方法就必须返回一个loss,至于training_step_end()主要是在使用多GPU训练的时候需要重写该方法,主要是进行损失的汇总。
除了training_step,我们还有validation_step,test_step,其中test_step不会在训练中调用,而validation_step则是对测试数据进行模型推理,一般在这个步骤里可以用self.log进行记录某些值,例如:
def validation_step(self, batch, batch_idx):
pre = model(batch)
loss = self.lossfun(...)
# log记录
self.log('val_loss',loss, on_epoch=True, prog_bar=True, logger=True)
上面的使用的self.log是非常重要的一个方法,这个方法继承自LightningModule这个父类,我们使用这里log就可以在训练时使用ModelCheckpoint对象(用于保存模型的参数对象)去检测测试步骤中的参数(比如这里我们就可以检测val_loss这个值,来确定是否保存这个模型参数)
self.log()中常用参数以下:
- prog_bar:如果是True,该值将会显示在进度条上
- logger:如果是True,将会记录到logger器中(会显示在tensorboard上)
2.2 LightningDataModule
这一个类必须包含的部分是setup(self, stage=None)方法,train_dataloader()方法。
- setup(self, stage=None):主要是进行Dataset的实例化,包括但不限于进行数据集的划分,划分成训练集和测试集,一般来说都是Dataset类
- train_dataloader():很简单,只需要返回一个DataLoader类即可。
-
val_dataloader():与train_dataloader一样,用于初始化DataLoader对象,并返回
有些时候也会定义collate_fn函数,在DataLoader创建时传入collate_fn参数,用于对Dataset进处理(但实际上一般是在Dataset类中定义,根据Dataset的属性类个性化配置)。
三、常见问题
参考网站:
pytorch_lightning 全程笔记 - 知乎 (zhihu.com)
Pytorch Lightning 完全攻略 - 知乎 (zhihu.com)
PyTorch Lightning初步教程(上) - 知乎 (zhihu.com)
“简约版”Pytorch —— Pytorch-Lightning详解_@YangZai的博客-CSDN博客
201024-5步PyTorchLightning中设置并访问tensorboard_专注机器学习之路-CSDN博客
PyTorch Ligntning】快速上手简明指南_闻韶-CSDN博客
PyTorch Lightning 工具学习 - 知乎 (zhihu.com)