PSPNet 代码分析
https://www.lmlphp.com/user/62501/article/item/1225624/
train.py
网络训练主函数,主要操作有:
- 传入训练参数;通常采用argparse库,支持脚本传入。
- 网络训练;包括定义网络、加载模型、前向反向传播、保存模型等。
- 将训练情况可视化;使用tensorboard绘制loss曲线。
datasets.py
在pytorch中数据加载到模型的操作顺序如下:
- 创建一个Dataset对象,一般重载
__len__
和__getitem__
方法。__len__
返回数据集大小,__getitem__
支持索引,以便Dataset[i]获取第i个样本。 - 创建一个DataLoader对象,将Dataset作为参数传入。
- 循环这个DataLoader对象,将img、label加载到模型中进行训练。
我们还需在Dataset对象中定义数据预处理,这里采用:
-
0.7-1.4倍的随机尺度缩放
-
各通道减去ImageNet的均值
-
随机crop下769x769大小
-
镜像随机翻转
注意:为了让Image和Label对应,也要对Label作相应的预处理,具体过程详见代码。