代码笔记19 pytorch报错, loaded state dict contains a parameter group that doesn't match the size of optimizer's group
问题
重新加载训练的时候出现了这样的问题。我的训练步骤是加载resnet,然后冻结resnet参数,解冻resnet中的bn层(为什么解冻我前面有博客),保存模型。之后再重新加载
解决
根据pytorch论坛上的一个帖子[1],然后根据他的说法是,optimizer的定义放在了参数冻结之前,导致所有参数都被传进去了
举个例子
optimizer = torch.optim.SGD(params=filter(lambda p: p.requires_grad, net.parameters()),
lr=args.learning_rate,
weight_decay=args.weight_decay,
momentum=args.momentum)
如果在冻结前定义optimizer,加载的pth文件中的参数(貌似torch.save并不能保存参数的requires_grad属性),导致filter并没有起到作用,因为所有的参数刚被加载进网络的时requires_grad属性均为True,这个时候再去冻结参数,等于说无法再改变被定义的optimizer中的param了,所以optimizer要放在冻结参数之后
if bool(load_dir) == False:
print('no model in {}'.format(load_dir))
os._exit(0)
net._load_resnet_pretrained()
net._freeze_parameters()
optimizer = torch.optim.SGD(params=filter(lambda p: p.requires_grad, net.parameters()),
lr=args.learning_rate,
weight_decay=args.weight_decay,
momentum=args.momentum)
pre_epoch = load_checkpoint(model=net, optimizer=optimizer, model_file=load_dir)
# optimizer.param_groups[0]['lr'] = 0.000005
pre_epoch += 1
if device == 'cuda':
for state in optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.cuda()
万事大吉
Refrences
[1] https://discuss.pytorch.org/t/error-when-loading-adam-optimizer-resume-training/50870