网络训练至某个epoch,参数 问题

1 start_epoch = params.start_epoch
2   stop_epoch = params.stop_epoch
3   if params.resume != '':
4     resume_file = get_resume_file('%s/checkpoints/%s'%(params.save_dir, params.resume), params.resume_epoch)  #get_resume_file函数得到epoch.tar文件
5     if resume_file is not None:
6       tmp = torch.load(resume_file)
7       start_epoch = tmp['epoch']+1
8       model.load_state_dict(tmp['state'])
9       print('  resume the training with at {} epoch (model file {})'.format(start_epoch, params.resume))

对tar文件进行加载,并且选取其中需要的字典权重

 1 tmp = torch.load(modelfile)   # load parameter file:400.tar
 2   try:
 3     state = tmp['state']
 4   except KeyError:
 5     state = tmp['model_state']
 6   except:
 7     raise
 8   state_keys = list(state.keys())  #列举字典中的key
 9   for i, key in enumerate(state_keys):
10     if "feature." in key and not 'gamma' in key and not 'beta' in key:
11       newkey = key.replace("feature.","")
12       state[newkey] = state.pop(key)  #删除该key并返回对应的值,不影响上面的训练
13     else:
14       state.pop(key)
15 
16   model.load_state_dict(state) 

 

posted on 2020-07-30 14:48  Yxter  阅读(695)  评论(0编辑  收藏  举报

导航