模型保存一般是在epoch结束时保存一次,但这样会存在随机性,可能epoch结束波动比较大,通过在单个epoch中平均多次提升模型鲁棒性
以下为 model_avg 计算过程,通过直接修改 model_avg.state_dict()
中storage
部分值对应的地址,达到修改model_avg目的
def avg_state_dict(
state_dict_cur: Dict[str, Tensor],
state_dict_avg: Dict[str, Tensor],
batch_idx_train,
average_period,
scaling_factor= 6.6
) -> Dict[str, Tensor]:
weight_avg = average_period / batch_idx_train
weight_cur = 1 - weight_avg
uniqued: Dict[int, str] = dict()
for k, v in state_dict_avg.items():
# if k == 'encoder.encoders.4.conv_module.pointwise_conv1.bias':
# print( v )
v_data_ptr = v.data_ptr() # 获取数据的地址
if v_data_ptr in uniqued:
continue
uniqued[v_data_ptr] = k
uniqued_names = list(uniqued.values())
for k in uniqued_names:
if state_dict_avg[k].dtype == torch.int64:
state_dict_avg[k] = state_dict_cur[k].to(device=state_dict_avg[k].device)
else:
state_dict_avg[k] *= weight_avg
state_dict_avg[k] += (
state_dict_cur[k].to(device=state_dict_avg[k].device) * weight_cur
)
state_dict_avg[k] *= scaling_factor
常规做法:计算后回传
def avg_state_dict(
state_dict_cur: Dict[str, Tensor],
state_dict_avg: Dict[str, Tensor],
batch_idx_train,
average_period,
scaling_factor= 6.6
) -> Dict[str, Tensor]:
...
return new_state_dict
load_state_dict(model_avg , avg_state_dict() )
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人