深度之眼PyTorch训练营第二期---16、模型保存与加载

一、序列化与反序列化

 

 1、torch.save

主要参数:

  • obj:对象
  • f:输出路径

2、torch.load

主要参数:

  • f:文件路径
  • map_location:指定存放位置,cpu or gpu

 

二、模型保存与加载的两种方式

第一种方式:

  保存整个Module

  torch.save(net,path)

第二种方式:

  state_dict = net.state_dict()

   torch.save(state_dict , path)

三、模型断点续训练

 

模型微调Finetune

四、Transfer Learning & Model Finetune

Transfer Learning:机器学习分支,研究源域(source domain)的知识如何应用到目标域(target domain)

Model Finetune:模型的迁移学习

 

模型微调的步骤:

1、获取预训练模型参数

2、加载模型(load_state_dict)

3、修改输出层

 

模型微调训练方法

1、固定预训练的参数(requires_grad = False; lr = 0)

2、Features Extractor较小学习率(params_group)

五、PyTorch中的Finetune

Finetune Resnet-18 用于二分类

蚂蚁蜜蜂二分类数据

训练集:各120~张  验证集:各70~张

 

posted @ 2019-11-26 22:24  cola_cola  阅读(228)  评论(0编辑  收藏  举报