加载预训练模型:修改类别数与不修改类别数

# Without changing the number of classes: every parameter in the checkpoint is
# loaded straight into `model`.
# The loaded object must expose `.state_dict()`, so the file presumably holds a
# whole pickled model (saved with `torch.save(model, path)`) rather than a bare
# state dict -- TODO confirm against the training script.
# NOTE(review): recent PyTorch releases may refuse to unpickle a full model
# unless `weights_only=False` is passed to `torch.load`; verify for the
# installed version.
checkpoint = torch.load("./save_model_other/best_model_pre.pth")
pretrained_state = checkpoint.state_dict()
# Print the parameter names so they can be compared with the target model's.
print(pretrained_state.keys())
model.load_state_dict(pretrained_state)

# With a changed number of classes: copy over only the checkpoint parameters
# whose names also exist in `model`, then drop the "model2.2" weight and bias
# (presumably the final classification layer, whose shape depends on the class
# count -- confirm against the model definition) so that layer keeps its fresh
# initialization.
# The snippet is kept disabled inside a triple-quoted string; remove the
# surrounding quotes to use it instead of the direct load.
'''
# Model parameter loading function: keeps the checkpoint entries whose key is
# also present in the target model and reports the ones that are not.
def transfer_state_dict(pretrained_dict, model_dict):
    state_dict = {}
    for k, v in pretrained_dict.state_dict().items():
        if k in model_dict.state_dict().keys():
            state_dict[k] = v
        else:
            print("Missing keys in state_dict: {}".format(k))
    return state_dict

checkpoint = torch.load("./save_model_other/best_model_pre.pth")
state_dict = transfer_state_dict(checkpoint, model)
# Discard the parameters that must not be carried over from the checkpoint.
del state_dict["model2.2.weight"]
del state_dict["model2.2.bias"]
# Start from the model's own parameters and overwrite the transferable ones,
# so the deleted entries keep the values the model was initialized with.
model_dict = model.state_dict()
model_dict.update(state_dict)
model.load_state_dict(model_dict)
'''
#print(checkpoint.state_dict().keys())
#model.load_state_dict(checkpoint.state_dict())

 

posted @ 2022-08-16 15:19  小丑_jk  阅读(208)  评论(0)  编辑  收藏  举报