修改模型参数名


import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from model import TTSR
from option import args
# model = TTSR.TTSR(args)
# print(model)
# Checkpoint whose parameter names use the old layer layout, and the path
# the renamed copy is written to (and re-read from for the sanity check).
model_path = './pretrain_model/model_00118.pt'
model_path_now = './now.pt'

# Old -> new parameter-name fragments (MainNet.b_<i>.* -> MainNet.b_mod.<i>.*).
# Each pair is applied with str.replace, so a key that merely *contains* the
# old name (for example one carrying an extra prefix) is renamed as well.
KEY_RENAMES = {
    'MainNet.b_1.weight': 'MainNet.b_mod.1.weight',
    'MainNet.b_1.bias': 'MainNet.b_mod.1.bias',
    'MainNet.b_4.weight': 'MainNet.b_mod.4.weight',
    'MainNet.b_4.bias': 'MainNet.b_mod.4.bias',
    'MainNet.b_6.weight': 'MainNet.b_mod.6.weight',
    'MainNet.b_6.bias': 'MainNet.b_mod.6.bias',
    'MainNet.b_8.weight': 'MainNet.b_mod.8.weight',
    'MainNet.b_8.bias': 'MainNet.b_mod.8.bias',
}


def rename_state_dict_keys(state_dict, renames=None):
    """Return a new dict with the parameter names in *state_dict* rewritten.

    For every key, each ``old`` substring is replaced by ``new`` for all
    ``(old, new)`` pairs in *renames*; values are passed through untouched.

    Args:
        state_dict: Mapping of parameter name -> value (e.g. the tensors
            returned by ``torch.load`` on a saved state dict).
        renames: Mapping of old name fragment -> new name fragment.
            Defaults to ``KEY_RENAMES``.

    Returns:
        A new ``dict`` with the renamed keys. If two keys end up with the
        same name, the value of the later one wins.
    """
    if renames is None:
        renames = KEY_RENAMES
    renamed = {}
    for key, value in state_dict.items():
        for old, new in renames.items():
            key = key.replace(old, new)
        renamed[key] = value
    return renamed


def main():
    """Rename the checkpoint's keys, save the result and print a sample."""
    ori_model_dict = torch.load(model_path)
    model_state_dict = rename_state_dict_keys(ori_model_dict)
    torch.save(model_state_dict, model_path_now)

    # Reload the saved file and print some renamed parameters as a check.
    now_model_dict = torch.load(model_path_now)
    # 'MainNet.b_6.*' was renamed above, so the old name would raise
    # KeyError here; look up the new name instead.
    print(now_model_dict['MainNet.b_mod.6.weight'])
    print(now_model_dict['MainNet.b_mod.6.bias'])


if __name__ == '__main__':
    main()

 


 

posted on 2021-04-07 20:53  cltt  阅读(327)  评论(0)  编辑  收藏  举报

导航