PSPNet 代码分析

https://www.lmlphp.com/user/62501/article/item/1225624/

 

train.py

网络训练主函数,主要操作有:

  1. 传入训练参数;通常采用argparse库,支持脚本传入。
  2. 网络训练;包括定义网络、加载模型、前向反向传播、保存模型等。
  3. 将训练情况可视化;使用tensorboard绘制loss曲线。

datasets.py

在pytorch中数据加载到模型的操作顺序如下:

  1. 创建一个Dataset对象,一般重载__len____getitem__方法。__len__返回数据集大小,__getitem__支持索引,以便Dataset[i]获取第i个样本。
  2. 创建一个DataLoader对象,将Dataset作为参数传入。
  3. 循环这个DataLoader对象,将img、label加载到模型中进行训练。

我们还需在Dataset对象中定义数据预处理,这里采用:

  1. 0.7-1.4倍的随机尺度缩放

  2. 各通道减去ImageNet的均值

  3. 随机crop下769x769大小

  4. 镜像随机翻转

注意:为了让Image和Label对应,也要对Label作相应的预处理,具体过程详见代码。

posted @ 2022-08-02 12:12  ethan178  Views(141)  Comments(0Edit  收藏  举报