模型保存、加载与转换
转载:https://www.zhihu.com/question/319682377
一 pytorch模型保存与加载
模型包括模型结构和模型参数,因此,有不同的保存模型的方式。
1 官方推荐
1.1 只保存模型参数
优点:只保存模型参数,该方法速度快,占用的空间少
缺点:重载模型时,必须知道模型的结构
保存:
model = VGGNet() torch.save(model.state_dict(), PATH) #存储model中的参数
加载:
new_model = VGGNet() #建立新模型 new_model.load_state_dict(torch.load(PATH)) #将model中的参数加载到new_model中
1.2 保存整个模型
该方法同时保存了模型结构和参数。
优点:重载时,不需要事先知道模型结构
缺点:1)模型占用资源较大,速度慢;2)save 的时候是将整个模型进行了保存,所以会需要在相同的目录结构下,这样能够得到静态模型,也就能依据静态模型进行load,所以load的时候必须存在原来的目录结构,否则会报错。
保存:
model = VGGNet() torch.save(model, PATH) #存储整个模型
加载:
new_model = torch.load(PATH) #将整个model加载到new_model中 #new_model 不再需要第一种方法中的建立新模型的步骤
保存文件格式说明
关于上面表达式中PATH参数的说明:
PATH参数是你保存文件的路径,并且需要指定保存文件的文件名,如:
torch.save(model, '/home/user/save_model/checkpoint.pth')
即将该模型保存在/home/user/save_model路径下的checkpoint.pth文件中,保存的文件格式约定为.pth或.pt
new_model = torch.load('/home/user/save_model/checkpoint.pth')
但是在pytorch1.6版本中,
torch.save存储的文件格式采用了新的基于压缩文件的格式 .pth.tar
torch.load依然保留了加载了旧格式.pth的能力
2 保存checkpoint(检查点)
通常在训练模型的过程中,可能会遭遇断电、断网的尴尬,一旦出现这种情况,先前训练的模型就白费了,又得重头开始训练。因此每隔一段时间就将训练模型信息保存一次很有必要。而这些信息不光包含模型的参数信息,还包含其他信息,如当前的迭代次数,优化器的参数等,以便用于后面恢复训练。
state = { 'epoch' : epoch + 1, #保存当前的迭代次数 'state_dict' : model.state_dict(), #保存模型参数 'optimizer' : optimizer.state_dict(), #保存优化器参数 ..., #其余一些想保持的参数都可以添加进来 ..., } torch.save(state, 'checkpoint.pth.tar') #将state中的信息保存到checkpoint.pth.tar #Pytorch 约定使用.tar格式来保存这些检查点 #当想恢复训练时 checkpoint = torch.load('checkpoint.pth.tar') epoch = checkpoint['epoch'] model.load_state_dict(checkpoint['state_dict']) #加载模型的参数 optimizer.load_state_dict(checkpoint['optimizer']) #加载优化器的参数
3 不同设备上的模型保存与加载
3.1 在CPU上存储,在GPU上加载
#CPU存储 torch.save(model.state_dict(), PATH) #GPU加载 device = torch.device('cuda') model = Model() model.load_state_dict(torch.load(PATH, map_location='cuda:0')) #可以选择任意GPU设备 model.to(device)
3.2 在GPU上存储,在CPU上加载
#GPU存储 torch.save(model.state_dict(), PATH) #CPU加载 device = torch.device('cpu') model = Model() model.load_state_dict(torch.load(PATH, map_location=device))
3.3 在GPU存储,在GPU加载
#GPU存储 torch.save(model.state_dict(), PATH) #GPU加载 device = torch.device('cuda') model = Model() model.load_state_dict(torch.load(PATH)) model.to(device)
3.4 存储和加载使用过torch.nn.DataParallel的模型
3.4.1 多卡训练,单卡加载部署
''' 这种情况要防止参数保存的时候没有加module,那么保存的参数名称是module.conv1.weight, 而单卡的参数名称是conv1.weight,这时就会报错,找不到相应的字典的错误。 此时可以通过手动的方式删减掉模型中前几位的名称,然后重新加载。 不懂代码可以先看一下第2部分内容模型参数存储内容解析 ''' model = torch.nn.DataParallel(model) #存储 torch.save(model.module.state_dict(), PATH) #加载 kwargs={'map_location':lambda storage, loc: storage.cuda(gpu_id)} def load_GPUS(model,model_path,kwargs): state_dict = torch.load(PATH, **kwargs) # create new OrderedDict that does not contain 'module.' from collections import OrderedDict new_state_dict = OrderedDict() for k, v in state_dict.items(): name = k[7:] # remove 'module.' new_state_dict[name] = v # load params model.load_state_dict(new_state_dict) return model
3.4.2 单卡训练,多卡加载部署
''' 此时唯有记住一点,因为单卡训练参数是没有module的,而多卡加载的参数是有module的, 因此需要保证参数加载在模型分发之前。 ''' #存储 torch.save(model.state_dict(), PATH) #加载 model.load_state_dict(torch.load(PATH)) model = torch.nn.DataParallel(model) #模型分发
3.4.3 多卡训练多卡加载部署
环境如果没有变化,则可以直接加载,如果环境有变化,则可以拆解成第1种情况,然后再分发模型。
4 模型变更后的参数加载操作
当我们使用像resnet50、resnet101这样的网络时,通常可以从网上对应下载到这些模型的预训练参数文件,但是我们所使用的模型,可能需要在resnet50或resnet101网络上进行一些修改,比如增加一些结构,或者删除一些结构。所以我们只希望加载修改后的模型与原来的模型之间具有相同结构部分的参数。
#假设下载到的原有模型参数文件为checkpoint.pth.tar model = OurModel() model_checkpoint = torch.load('checkpoint.pth.tar') pretrain_model_dict = model_checkpoint['state_dict'] model_dict = model.state_dict() same_model_dict = {k : v for k, v in pretrain_model_dict if k in model_dict} model_dict.update(same_model_dict) model.load_state_dict(model_dict)