pytorch-模型保存和加载

pytorch-模型保存和加载

加载模型参数和选择是由保存的模型数据结构决定,故先要确定保存模型模型的方法和数据结构

保存模型

# 模型权重参数
model.state_dict()
'''首先说一下 model.state_dict()
pytorch 中的 model.state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)
只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等

state_dict是在定义了model或optimizer之后pytorch自动生成的
'''
# model.state_dict() 其实返回的是一个OrderDict,存储了网络结构的名字和对应的参数


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(1, 2)
        self.linear2 = nn.Linear(2, 1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)

        return x


mode = Net()
print(mode.state_dict())
"""
OrderedDict([('linear1.weight', tensor([[ 0.8108],[-0.7968]])), ('linear1.bias', tensor([ 0.2680, -0.4772])), ('linear2.weight', tensor([[-0.7066, -0.3334]])), ('linear2.bias', tensor([0.4819]))])

"""

print(mode.state_dict().keys())
"""
odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
"""

for param_tensor in model.state_dict():
    #打印 key value字典
    print(param_tensor,'\t',model.state_dict()[param_tensor].size())
    
"""
linear1.weight 	 torch.Size([2, 1])
linear1.bias 	 torch.Size([2])
linear2.weight 	 torch.Size([1, 2])
linear2.bias 	 torch.Size([1])
"""
# 保存模型

torch.save(obj, f, pickle_module,pickle_protocol )
"""输入参数
obj	   			可以是单个值也可以字典、对象
f 	   			要保存参数的文件路径
pickle_module
pickle_protocol
"""

# 1、自定义保存-工程实践中常常使用---推荐
state = {'model':     model.state_dict(), 
         'optimizer': optimizer.state_dict(), 
         'epoch':     epoch   }
torch.save(model_object, './model.pt')  

# 2、仅仅是保存模型权重参数
torch.save(model.state_dict(), PATH)

# 3、直接保存整个模型和模型结构
torch.save(Net,PATH)

加载模型

参数的保存

torch.save(model_object.state_dict(), 'params.pth') 

# 模型的加载有模型保存的数据结构决定
ckpt = torch.load(f, map_location=None)
"""输入参数
f					file模型文件
map_location		torch.device, 动态地进行内存重映射,从不同的设备上读取文件
pickle_module 		用于unpickling元数据和对象的模块
pickle_load_args 	传递给pickle_module.load()

注释: 如果多块显卡,map_location={'cuda:0':"cuda:1"},指定在2号显卡,不使用1号显卡
返回参数 字典d
由加载文件定义
默认情况,dict_keys(['epoch', 'state_dict', 'optimizer', 'best_pred'])
"""

# 1、针对第一种保存模型的加载方式
# 加载模型
model=Net()										
# 加载模型参数
model_CKPT = torch.load(checkpoint_PATH) 
# 参数各个属性f
model.load_state_dict(model_CKPT['model'])  
optimizer.load_state_dict(model_CKPT['optimizer'])

# 2、针对第二种保存模型的加载方式
model=Net()									# 实例化网络
model_CKPT = torch.load(checkpoint_PATH)    # 加载模型参数
model.load_state_dict(model_CKPT)  

# 针对第三种保存整个模型的加载方式
model = torch.load(mode_PATH)

部分权重的加载

# 关键自定义函数

def intersect_dicts(da, db, exclude=()):
    """输入参数
    da (state_dict)			 加载权重的 state_dict 
    db (state_dict) 	 	 加载模型的 state_dict
    exclude (list)           不想要的权重 keys()
    
    返回参数
    加载的部分权重 (state_dict)
    """	
    '''
    print("exclude",exclude)
    for k, v in da.items():
        for x in exclude:
            if x in k:
                print('@ ',x ,k)
            if v.shape != db[k].shape:
                print('# ', x, k)
	'''
    
    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}

案例

# 加载模型
model = Net()

# 加载权重
ckpt=torch.load(weights, map_location=device)
state_dict=ckpt.state_dict()
# state_dict 是一个字典 

# state_dict.keys()
# odict_keys(['0.model.0.conv.conv.weight', '0.model.0.conv.conv.bias', '0.model.1.conv.weight', .....])

# 权重取舍处理
state_dict=intersect_dicts(state_dict, model.state_dict(), exclude=exclude)

# 模型加载权重
model.load_state_dict(state_dict, strict=False)

# 最后可以输出加载了多少个
print('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))
# output >>> Transferred 498/506 items from yolov5m.pt
posted @ 2021-06-23 21:20  贝壳里的星海  阅读(1112)  评论(0编辑  收藏  举报