Pytorch——Pytorch Lightning框架使用手册

  本文主要是记录下,使用PytorchLightning这个如何进行深度学习的训练,记录一下本人平常使用这个框架所需要注意的地方,由于框架的理解深入本文会时不时进行更新(第三部分的常见问题会是不是的更新走的),本文深度参考以下两个网站pytorch_lightning 全程笔记 Pytorch Lightning 完全攻略如果大家觉得本文写得不是很清楚,大家可以进一步看看这两篇文章。

一、框架使用方案

  正如网络上大家介绍的那样,PL框架可以让人专心在模型内部的研究。我们在复杂的项目中,可能会出现多个模型,并且模型多个模型之间存在着许多的联系,如果在项目中想要更换某些模型model,会导致重写很多代码。但是如果采用PL框架,那么这将会是一件比较容易的事情。根据Pytorch Lightning 完全攻略这篇文章的推荐,我建议采用以下的代码风格:

root-
    |-dataModule
        |-__init__.py
        |-data_interface.py
        |-xxxdataset1.py
        |-xxxdataset2.py
        |-...
    |-modelModule
        |-__init__.py
        |-model_interface.py
        |-xxxmodel1.py
        |-xxxmodel2.py
        |-...
    |-train.py

  其中把dataModule和modelModule写成python包,这两个包的__init__.py分别是:

  • from .data_interface import DInterface
  • from .model_interface import MInterface

  在DInterface和MInterface分别是data_interface.pymodel_interface.py中创建的类,他们两个分别就是

  • class DInterface(pl.LightningDataModule): 用于所有数据集的接口,在setup()方法中初始化你准备好的xxxdataset1.py,xxxdataset2.py中定义的torch.utils.data.Dataset类。在train_dataloader,val_dataloader,test_dataloader这几个方法中载入Dataloader即可。
  • class MInterface(pl.LightningModule): 用作模型的接口,在__init__()函数中import你准备好的xxxmodel2.py,xxxmodel1.py这些模型。重写training_step方法,validation_step方法,configure_optimizers方法。

  当大家在更改模型的时候只需要在对应的模块上进行更改即可,最后train.py主要功能就是读取参数,和调用dataModule和modelModule这两个包进行实例化DInterface和MInterface,当然一些PL框架的回调函数也需要在train.py里进行定义。

二、框架基本模块(Module)

2.1 LightningModule

  LightningModule必须包含的部分是init()和training_step(self, batch, batch_idx),其中init()主要是进行模型的初始化和定义(不需要定义数据集等)。training_step(...)主要是进行定义每个batch数据的处理函数。

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

  这里的batch就是从dataloader中出来的一个batch的数据,类似于for batch in dataloader :。可以看到这个函数的返回值是一个loss,文档中有提到这个方法一定要返回一个loss,如果是返回一个字典,那么必须包含“loss”这个键。当然也有例外,当我们重写training_step_end()方法的时候就不用training_step必须返回一个loss了,此时可以返回任意的东西,但是要注意training_step_end()方法就必须返回一个loss,至于training_step_end()主要是在使用多GPU训练的时候需要重写该方法,主要是进行损失的汇总。

  除了training_step,我们还有validation_step,test_step,其中test_step不会在训练中调用,而validation_step则是对测试数据进行模型推理,一般在这个步骤里可以用self.log进行记录某些值,例如:

def validation_step(self, batch, batch_idx): 
    pre = model(batch)
    loss = self.lossfun(...)
    # log记录
    self.log('val_loss',loss, on_epoch=True, prog_bar=True, logger=True)

  上面的使用的self.log是非常重要的一个方法,这个方法继承自LightningModule这个父类,我们使用这里log就可以在训练时使用ModelCheckpoint对象(用于保存模型的参数对象)去检测测试步骤中的参数(比如这里我们就可以检测val_loss这个值,来确定是否保存这个模型参数)

 

  self.log()中常用参数以下:

  • prog_bar:如果是True,该值将会显示在进度条上
  • logger:如果是True,将会记录到logger器中(会显示在tensorboard上)

2.2 LightningDataModule

  这一个类必须包含的部分是setup(self, stage=None)方法,train_dataloader()方法。

  • setup(self, stage=None):主要是进行Dataset的实例化,包括但不限于进行数据集的划分,划分成训练集和测试集,一般来说都是Dataset类
  • train_dataloader():很简单,只需要返回一个DataLoader类即可。
  • val_dataloader():与train_dataloader一样,用于初始化DataLoader对象,并返回

  有些时候也会定义collate_fn函数,在DataLoader创建时传入collate_fn参数,用于对Dataset进处理(但实际上一般是在Dataset类中定义,根据Dataset的属性类个性化配置)。

 三、常见问题

 

 

 

参考网站:

pytorch_lightning 全程笔记 - 知乎 (zhihu.com)

Pytorch Lightning 完全攻略 - 知乎 (zhihu.com)

PyTorch Lightning初步教程(上) - 知乎 (zhihu.com)

 “简约版”Pytorch —— Pytorch-Lightning详解_@YangZai的博客-CSDN博客

 201024-5步PyTorchLightning中设置并访问tensorboard_专注机器学习之路-CSDN博客

PyTorch Ligntning】快速上手简明指南_闻韶-CSDN博客

PyTorch Lightning 工具学习 - 知乎 (zhihu.com)

 

posted @ 2021-12-26 16:04  Circle_Wang  阅读(3605)  评论(0编辑  收藏  举报