【猫狗数据集】从命令行接收参数
数据集下载地址:
链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw
提取码:2xq4
创建数据集:https://www.cnblogs.com/xiximayou/p/12398285.html
读取数据集:https://www.cnblogs.com/xiximayou/p/12422827.html
进行训练:https://www.cnblogs.com/xiximayou/p/12448300.html
保存模型并继续进行训练:https://www.cnblogs.com/xiximayou/p/12452624.html
加载保存的模型并测试:https://www.cnblogs.com/xiximayou/p/12459499.html
划分验证集并边训练边验证:https://www.cnblogs.com/xiximayou/p/12464738.html
使用学习率衰减策略并边训练边测试:https://www.cnblogs.com/xiximayou/p/12468010.html
利用tensorboard可视化训练和测试过程:https://www.cnblogs.com/xiximayou/p/12482573.html
epoch、batchsize、step之间的关系:https://www.cnblogs.com/xiximayou/p/12405485.html
本节我们要在命令行接收参数,包括batch_size的值以及网络的类型。
基本上我们只需要修改main.py就行了:
main.py
import sys sys.path.append("/content/drive/My Drive/colab notebooks") from utils import rdata from model import resnet import torch.nn as nn import torch import numpy as np import torchvision import train import torch.optim as optim np.random.seed(0) torch.manual_seed(0) torch.cuda.manual_seed_all(0) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = True device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def main(batch_size,baseline): train_loader,val_loader,test_loader=rdata.load_dataset(batch_size) if baseline: model =torchvision.models.resnet18(pretrained=False) model.fc = nn.Linear(model.fc.in_features,2,bias=False) if torch.cuda.is_available(): model.cuda() #定义训练的epochs num_epochs=100 #定义学习率 learning_rate=0.1 #定义损失函数 criterion=nn.CrossEntropyLoss() #定义优化方法,简单起见,就是用带动量的随机梯度下降 optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9, weight_decay=1*1e-4) scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [40,80], 0.1) print("训练集有:",len(train_loader.dataset)) #print("验证集有:",len(val_loader.dataset)) print("测试集有:",len(test_loader.dataset)) trainer=train.Trainer(criterion,optimizer,model) trainer.loop(num_epochs,train_loader,val_loader,test_loader,scheduler) if __name__ == "__main__": import argparse p=argparse.ArgumentParser() p.add_argument("--batch_size",type=int,default=64) p.add_argument("--baseline",action="store_true") args=p.parse_args() main(args.batch_size,args.baseline)
说明:我们将读取数据集、定义损失、优化器等代码放入到main()函数中,然后给main传入batch_size和baseline。使用argparse可以从命令行接收参数。add_argument()函数中,第一个参数是参数的名称,第二个是参数的类型,default是默认值,即不在命令行输入--batch_size 具体值,则会使用默认值。需要关注的是action="store_true",该参数的意思是默认baseline为False,如果在命令行中加入了--baseline,则baseline的值就为True。
结果如图所示:
没有加--batch_size,则batch_size默认为64,也就是18255/64约等于286。然后我们使用了--baseline,即默认使用resnet18模型。
由于图像分类一般考虑的衡量指标是top1和top5,下一节就是加上计算top5的代码了。