Model checkpoints are usually saved once at the end of each epoch, but this is somewhat random: the parameters at the exact end of an epoch can fluctuate quite a bit. Averaging the model several times within a single epoch makes the saved model more robust.
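For context, here is a minimal sketch (an assumption, not the original training script) of how this kind of in-epoch averaging is usually driven: every average_period batches the current model is folded into a running average model_avg, using the avg_state_dict helper shown below.

import copy

from torch import nn

# Hypothetical setup; the real model, optimizer and data loader come from
# the training script.
model = nn.Linear(4, 4)
model_avg = copy.deepcopy(model)  # running average, updated in place
average_period = 100

for batch_idx_train in range(1, 1001):
    # ... forward / backward / optimizer step on `model` ...
    if batch_idx_train % average_period == 0:
        avg_state_dict(
            state_dict_cur=model.state_dict(),
            state_dict_avg=model_avg.state_dict(),
            batch_idx_train=batch_idx_train,
            average_period=average_period,
            scaling_factor=1.0,  # keep a plain running average in this sketch
        )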
The following is how model_avg is computed: the code writes directly into the storage behind the values of model_avg.state_dict(), which updates model_avg itself in place.
from typing import Dict

import torch
from torch import Tensor


def avg_state_dict(
    state_dict_cur: Dict[str, Tensor],
    state_dict_avg: Dict[str, Tensor],
    batch_idx_train: int,
    average_period: int,
    scaling_factor: float = 6.6,
) -> None:
    """In-place update:
    state_dict_avg = (state_dict_avg * weight_avg
                      + state_dict_cur * weight_cur) * scaling_factor
    """
    # The current model gets weight average_period / batch_idx_train, so
    # model_avg stays a running average of the models sampled every
    # average_period batches.
    weight_cur = average_period / batch_idx_train
    weight_avg = 1 - weight_cur

    # Parameters that share storage (same data address) are averaged only once.
    uniqued: Dict[int, str] = dict()
    for k, v in state_dict_avg.items():
        v_data_ptr = v.data_ptr()  # address of the underlying storage
        if v_data_ptr in uniqued:
            continue
        uniqued[v_data_ptr] = k

    uniqued_names = list(uniqued.values())
    for k in uniqued_names:
        if state_dict_avg[k].dtype == torch.int64:
            # Integer buffers (e.g. num_batches_tracked) are taken from the
            # current model instead of being averaged.
            state_dict_avg[k] = state_dict_cur[k].to(
                device=state_dict_avg[k].device
            )
        else:
            # In-place ops write straight through to model_avg's parameters.
            state_dict_avg[k] *= weight_avg
            state_dict_avg[k] += (
                state_dict_cur[k].to(device=state_dict_avg[k].device)
                * weight_cur
            )
            state_dict_avg[k] *= scaling_factor
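Why the in-place version needs no load_state_dict: by default state_dict() returns tensors that share storage with the module's own parameters, so *= and += on a state_dict value write straight into the module. A small self-contained check (toy module, not part of the original code):

import torch
from torch import nn

toy = nn.Linear(2, 2, bias=False)
sd = toy.state_dict()  # values share storage with toy.weight
assert sd["weight"].data_ptr() == toy.weight.data_ptr()

sd["weight"] *= 0.0    # in-place write through the state_dict entry
print(toy.weight)      # all zeros: the module itself was modified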
The conventional approach: compute a new averaged state dict, then load it back into the model.
def avg_state_dict(
    state_dict_cur: Dict[str, Tensor],
    state_dict_avg: Dict[str, Tensor],
    batch_idx_train: int,
    average_period: int,
    scaling_factor: float = 6.6,
) -> Dict[str, Tensor]:
    ...  # compute new_state_dict from state_dict_avg and state_dict_cur
    return new_state_dict


model_avg.load_state_dict(avg_state_dict(...))
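A sketch of what the elided body could look like, assuming the same weighting as the in-place version above (illustration only, not the original implementation); load_state_dict then copies the freshly built tensors into model_avg:

from typing import Dict

import torch
from torch import Tensor


def avg_state_dict(
    state_dict_cur: Dict[str, Tensor],
    state_dict_avg: Dict[str, Tensor],
    batch_idx_train: int,
    average_period: int,
    scaling_factor: float = 6.6,
) -> Dict[str, Tensor]:
    weight_cur = average_period / batch_idx_train
    weight_avg = 1 - weight_cur
    new_state_dict: Dict[str, Tensor] = {}
    for k, v_avg in state_dict_avg.items():
        v_cur = state_dict_cur[k].to(device=v_avg.device)
        if v_avg.dtype == torch.int64:
            new_state_dict[k] = v_cur.clone()
        else:
            # Out-of-place: allocate new tensors instead of reusing storage.
            new_state_dict[k] = (v_avg * weight_avg + v_cur * weight_cur) * scaling_factor
    return new_state_dict


model_avg.load_state_dict(
    avg_state_dict(
        state_dict_cur=model.state_dict(),
        state_dict_avg=model_avg.state_dict(),
        batch_idx_train=batch_idx_train,
        average_period=average_period,
    )
)

Compared with the in-place version, this allocates a new tensor per parameter and pays an extra copy inside load_state_dict, which is exactly what the storage-level in-place update avoids.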