Pytorch 基本流程记录
基本流程
数据
继承torch.utils.data.dataset
__init__(): 记录数据的路径
__getitem__(): 读取图像,进行预处理,转化成Tensor,返回。
__len__(): 记录长度
网络
继承nn.Module
__init__(): 设计网络的各个模块
forward(): 搭建网络,设计输出
训练
构建网络,加载到gpu
通过dataloader读入训练数据
将数据输入网络得到网络返回结果
设计损失函数,选择优化器
根据网络返回结果计算损失,损失反向传播计算梯度
优化器根据梯度更新参数,清空梯度
打印日志,保留中间结果,保存模型
测试
构建网络,加载到gpu
通过dataloader读入测试数据
将数据输入网络得到返回结果
对结果进行后处理
保存结果,或直接和Ground Truth计算评价指标