修改模型参数名
"""Rename TTSR checkpoint keys ``MainNet.b_<i>.*`` -> ``MainNet.b_mod.<i>.*``.

Loads ``./pretrain_model/model_00118.pt``, renames the weight/bias entries of
the layers listed in ``RENAMED_LAYER_INDICES``, saves the result to
``./now.pt`` and reloads that file to print two of the renamed tensors.
Presumably this adapts an old checkpoint to a model that now groups these
layers in a ``b_mod`` container -- TODO confirm against ``model.TTSR``.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# NOTE(review): the imports below are not used by this script; they were kept
# from the original, which used them in a commented-out debug snippet
# (`model = TTSR.TTSR(args); print(model)`). Importing `option` may have side
# effects (e.g. command-line parsing) -- confirm before removing.
from model import TTSR
from option import args

MODEL_PATH = './pretrain_model/model_00118.pt'
OUTPUT_PATH = './now.pt'

# Layer indices whose parameters move from `MainNet.b_<i>` to `MainNet.b_mod.<i>`.
RENAMED_LAYER_INDICES = (1, 4, 6, 8)

# Old key fragment -> new key fragment,
# e.g. 'MainNet.b_1.weight' -> 'MainNet.b_mod.1.weight'.
KEY_RENAMES = {
    f'MainNet.b_{i}.{param}': f'MainNet.b_mod.{i}.{param}'
    for i in RENAMED_LAYER_INDICES
    for param in ('weight', 'bias')
}


def rename_state_dict_keys(state_dict, renames=KEY_RENAMES):
    """Return a new dict whose keys have every fragment in *renames* substituted.

    Each key has the substitutions applied in order with ``str.replace``, so a
    fragment is also renamed when it appears inside a longer key (for example
    behind a ``module.`` prefix). Values are shared with *state_dict*, not
    copied, and key order is preserved.

    Args:
        state_dict: mapping of parameter name -> value (tensors in practice).
        renames: mapping of old key fragment -> replacement fragment.

    Returns:
        A plain ``dict`` with the renamed keys.

    Raises:
        ValueError: if two input keys end up with the same output key; the
            previous implementation silently kept only the last one.
    """
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in renames.items():
            new_key = new_key.replace(old, new)
        if new_key in renamed:
            raise ValueError(
                f'renaming {key!r} collides with existing key {new_key!r}'
            )
        renamed[new_key] = value
    return renamed


def main():
    """Convert the checkpoint, save it, then reload it and print two tensors."""
    original = torch.load(MODEL_PATH)
    converted = rename_state_dict_keys(original)
    torch.save(converted, OUTPUT_PATH)

    # Reload the saved file and print parameters under their NEW names. The
    # old names ('MainNet.b_6.weight', ...) no longer exist after renaming, so
    # looking them up would raise KeyError.
    reloaded = torch.load(OUTPUT_PATH)
    print(reloaded['MainNet.b_mod.6.weight'])
    print(reloaded['MainNet.b_mod.6.bias'])


if __name__ == '__main__':
    main()