权重pth读取更换其值重新保存方法(附带参数计算)

本文主要解决模型权重迁移,主要使用pytorch读取某个权重,将其赋值给新权重格式,以下为原始代码:

 

顺带参数计算函数代码:

参数计算:

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

 

 权重更改代码如下:

if __name__ == '__main__':
    train_pth_root=r'D:\Users\User\Desktop\mask-try\epoch_100.pth'  # 模型训练后得到的权重 
    pre_pth_root = r'D:\Users\User\Desktop\mask-try\resnet50-19c8e357.pth'  # 原始预训练权重,如mmdet的resnet预训练权重
    train_net=torch.load(train_pth_root)
    net_state_dict = train_net['state_dict']  # 训练模型权重保存字典键值
    pre_net=torch.load(pre_pth_root)
    # 以下替换和更改成预训练权重格式,这里需要根据具体情况决定,本代码是基于mmdection修改的
    keys_lst=[k.replace('backbone.','') for k in net_state_dict.keys() if 'backbone.' in k]
    for k,v in pre_net.items():
        if k in keys_lst:
            k_new='backbone.'+k
            pre_net[k]=net_state_dict[k_new]
    # 保存新权重
    torch.save(pre_net,'D:/Users/User/Desktop/mask-try/fasterrcnn_adaw.pth')

 

 

 

posted @ 2022-03-15 16:33  tangjunjun  阅读(929)  评论(0编辑  收藏  举报
https://rpc.cnblogs.com/metaweblog/tangjunjun