Pytorch 基本流程记录

基本流程

数据

继承torch.utils.data.dataset

__init__(): 记录数据的路径

__getitem__(): 读取图像,进行预处理,转化成Tensor,返回。

__len__(): 记录长度

网络

继承nn.Module

__init__(): 设计网络的各个模块

forward(): 搭建网络,设计输出

训练

构建网络,加载到gpu

通过dataloader读入训练数据

将数据输入网络得到网络返回结果

设计损失函数,选择优化器

根据网络返回结果计算损失,损失反向传播计算梯度

优化器根据梯度更新参数,清空梯度

打印日志,保留中间结果,保存模型

测试

构建网络,加载到gpu

通过dataloader读入测试数据

将数据输入网络得到返回结果

对结果进行后处理

保存结果,或直接和Ground Truth计算评价指标

 

posted @ 2019-12-22 16:21  这是一个ID  阅读(381)  评论(0编辑  收藏  举报