代码笔记19 pytorch报错, loaded state dict contains a parameter group that doesn't match the size of optimizer's group

问题

  重新加载训练的时候出现了这样的问题。我的训练步骤是加载resnet,然后冻结resnet参数,解冻resnet中的bn层(为什么解冻我前面有博客),保存模型。之后再重新加载

解决

  根据pytorch论坛上的一个帖子[1],然后根据他的说法是,optimizer的定义放在了参数冻结之前,导致所有参数都被传进去了
举个例子

    optimizer = torch.optim.SGD(params=filter(lambda p: p.requires_grad, net.parameters()),
                                lr=args.learning_rate,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum)

如果在冻结前定义optimizer,加载的pth文件中的参数(貌似torch.save并不能保存参数的requires_grad属性),导致filter并没有起到作用,因为所有的参数刚被加载进网络的时requires_grad属性均为True,这个时候再去冻结参数,等于说无法再改变被定义的optimizer中的param了,所以optimizer要放在冻结参数之后

 if bool(load_dir) == False:
            print('no model in {}'.format(load_dir))
            os._exit(0)
        net._load_resnet_pretrained()
        net._freeze_parameters()
        optimizer = torch.optim.SGD(params=filter(lambda p: p.requires_grad, net.parameters()),
                                    lr=args.learning_rate,
                                    weight_decay=args.weight_decay,
                                    momentum=args.momentum)
        pre_epoch = load_checkpoint(model=net, optimizer=optimizer, model_file=load_dir)
        # optimizer.param_groups[0]['lr'] = 0.000005
        pre_epoch += 1
        if device == 'cuda':
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()

万事大吉

Refrences

[1] https://discuss.pytorch.org/t/error-when-loading-adam-optimizer-resume-training/50870

posted @ 2022-06-20 11:04  The1912  阅读(3642)  评论(0编辑  收藏  举报