PyTorch | 项目结构解析
在学习和使用深度学习框架时,复现现有项目代码是必经之路,也能加深对理论知识的理解,提高动手能力。本文参照相关博客整理项目常用组织方式,以及每部分功能,帮助更好的理解复现项目流程,文末提供分类示例项目。
1 项目组织
在做深度学习实验或项目时,为了得到最优的模型结果,中间往往需要很多次的尝试和修改。一般项目都包含以下几个部分:
- 模型定义
- 数据处理和加载
- 训练模型(Train&Validate)
- 训练过程的可视化
- 测试(Test/Inference)
另外程序在组织过程中还应该满足以下几个要求:
- 模型需具有高度可配置性,便于修改参数、修改模型,反复实验
- 代码应具有良好的组织结构,使人一目了然
- 代码应具有良好的说明,使其他人能够理解
2 项目结构
- checkpoints/: 用于保存训练好的模型,可使程序在异常退出后仍能重新载入模型,恢复训练
- data/:数据相关操作,包括数据预处理、dataset实现等
- models/:模型定义,可以有多个模型,例如上面的AlexNet和ResNet34,一个模型对应一个文件
- utils/:可能用到的工具函数,在本次实验中主要是封装了可视化工具
- config.py:配置文件,所有可配置的变量都集中在此,并提供默认值
- main.py:主文件,训练和测试程序的入口,可通过不同的命令来指定不同的操作和参数
- requirements.txt:程序依赖的第三方库
- README.md:提供程序的必要说明
3 解析
3.1 __init__
- __init__ 可以为空,也可以定义包的属性和方法,但必须存在,其他程序才能从这个目录中读取模块和函数
3.2 数据加载
使用Dataset提供数据集的封装,再使用Dataloader实现数据并行加载。
- def __init__(self..)
获取图片地址,并根据训练、验证和测试划分数据
- def __getitem__(self, index):
返回图片的数据和label
- def __len__(self):
返回数据集数量
train_dataset = DogCat(opt.train_data_root, train=True) trainloader = DataLoader(train_dataset, batch_size = opt.batch_size, shuffle = True, num_workers = opt.num_workers) for ii, (data, label) in enumerate(trainloader): train()
3.3 模型定义
型的定义主要保存在models/目录下,其中BasicModule是对nn.Module的简易封装,提供快速加载和保存模型的接口。
nn.Module主要包括save和load两个方法
from models import AlexNet
关于模型定义:
- 尽量使用nn.Sequential(比如AlexNet)
- 将经常使用的结构封装成子Module(比如GoogLeNet的Inception结构,ResNet的Residual Block结构)
- 将重复且有规律性的结构,用函数生成(比如VGG的多种变体,ResNet多种变体都是由多个重复卷积层组成)
3.4 工具函数
可能会用到一些helper方法,这些方法可以统一放在utils/文件夹下,需要使用时再引入。在本例中主要是封装了可视化工具visdom的一些操作,
3.5 配置文件
可配置的参数主要包括:
数据集参数(文件路径、batch_size等)
训练参数(学习率、训练epoch等)
模型参数
在实际使用时,并不需要每次都修改config.py,只需要通过命令行传入所需参数,覆盖默认配置即可。
3.6 main函数
提到了fire
main中包括train、val、test、help等
训练的主要步骤如下:
- 定义网络
- 定义数据
- 定义损失函数和优化器
- 计算重要指标
- 开始训练
- 训练网络
- 可视化各种指标
- 计算在验证集上的指标
4 示例分类代码
#coding:utf8 from config import opt import os import torch as t import models from data.dataset import DogCat from torch.utils.data import DataLoader from torch.autograd import Variable from torchnet import meter from utils.visualize import Visualizer from tqdm import tqdm def test(**kwargs): opt.parse(kwargs) import ipdb; ipdb.set_trace() # configure model model = getattr(models, opt.model)().eval() if opt.load_model_path: model.load(opt.load_model_path) if opt.use_gpu: model.cuda() # data train_data = DogCat(opt.test_data_root,test=True) test_dataloader = DataLoader(train_data,batch_size=opt.batch_size,shuffle=False,num_workers=opt.num_workers) results = [] for ii,(data,path) in enumerate(test_dataloader): input = t.autograd.Variable(data,volatile = True) if opt.use_gpu: input = input.cuda() score = model(input) probability = t.nn.functional.softmax(score)[:,0].data.tolist() # label = score.max(dim = 1)[1].data.tolist() batch_results = [(path_,probability_) for path_,probability_ in zip(path,probability) ] results += batch_results write_csv(results,opt.result_file) return results def write_csv(results,file_name): import csv with open(file_name,'w') as f: writer = csv.writer(f) writer.writerow(['id','label']) writer.writerows(results) def train(**kwargs): opt.parse(kwargs) vis = Visualizer(opt.env) # step1: configure model model = getattr(models, opt.model)() if opt.load_model_path: model.load(opt.load_model_path) if opt.use_gpu: model.cuda() # step2: data train_data = DogCat(opt.train_data_root,train=True) val_data = DogCat(opt.train_data_root,train=False) train_dataloader = DataLoader(train_data,opt.batch_size, shuffle=True,num_workers=opt.num_workers) val_dataloader = DataLoader(val_data,opt.batch_size, shuffle=False,num_workers=opt.num_workers) # step3: criterion and optimizer criterion = t.nn.CrossEntropyLoss() lr = opt.lr optimizer = t.optim.Adam(model.parameters(),lr = lr,weight_decay = opt.weight_decay) # step4: meters loss_meter = meter.AverageValueMeter() confusion_matrix = meter.ConfusionMeter(2) previous_loss = 1e100 # train for epoch in range(opt.max_epoch): loss_meter.reset() confusion_matrix.reset() for ii,(data,label) in tqdm(enumerate(train_dataloader),total=len(train_data)): # train model input = Variable(data) target = Variable(label) if opt.use_gpu: input = input.cuda() target = target.cuda() optimizer.zero_grad() score = model(input) loss = criterion(score,target) loss.backward() optimizer.step() # meters update and visualize loss_meter.add(loss.data[0]) confusion_matrix.add(score.data, target.data) if ii%opt.print_freq==opt.print_freq-1: vis.plot('loss', loss_meter.value()[0]) # 进入debug模式 if os.path.exists(opt.debug_file): import ipdb; ipdb.set_trace() model.save() # validate and visualize val_cm,val_accuracy = val(model,val_dataloader) vis.plot('val_accuracy',val_accuracy) vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format( epoch = epoch,loss = loss_meter.value()[0],val_cm = str(val_cm.value()),train_cm=str(confusion_matrix.value()),lr=lr)) # update learning rate if loss_meter.value()[0] > previous_loss: lr = lr * opt.lr_decay # 第二种降低学习率的方法:不会有moment等信息的丢失 for param_group in optimizer.param_groups: param_group['lr'] = lr previous_loss = loss_meter.value()[0] def val(model,dataloader): ''' 计算模型在验证集上的准确率等信息 ''' model.eval() confusion_matrix = meter.ConfusionMeter(2) for ii, data in enumerate(dataloader): input, label = data val_input = Variable(input, volatile=True) val_label = Variable(label.type(t.LongTensor), volatile=True) if opt.use_gpu: val_input = val_input.cuda() val_label = val_label.cuda() score = model(val_input) confusion_matrix.add(score.data.squeeze(), label.type(t.LongTensor)) model.train() cm_value = confusion_matrix.value() accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) / (cm_value.sum()) return confusion_matrix, accuracy def help(): ''' 打印帮助的信息: python file.py help ''' print(''' usage : python file.py <function> [--args=value] <function> := train | test | help example: python {0} train --env='env0701' --lr=0.01 python {0} test --dataset='path/to/dataset/root/' python {0} help avaiable args:'''.format(__file__)) from inspect import getsource source = (getsource(opt.__class__)) print(source) if __name__=='__main__': import fire fire.Fire()
参考:https://github.com/chenyuntc/pytorch-best-practice/blob/master/PyTorch%E5%AE%9E%E6%88%98%E6%8C%87%E5%8D%97.md