Coding Poineer

An unconventional way to modify a model (PyTorch)

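A model is usually saved once at the end of each epoch, but this introduces randomness: the weights can fluctuate considerably at the moment the epoch ends. Averaging the weights several times within a single epoch makes the saved model more robust.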
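The averaging idea can be illustrated with plain numbers before applying it to tensors. The sketch below is illustrative only: `running_average` and the sample values are hypothetical, and it assumes a snapshot is taken every `average_period` batches, so the n-th snapshot gets weight `1 / n` (that is, `average_period / batch_idx_train`) and the existing average keeps the rest.

```python
def running_average(samples):
    """Incrementally average samples: avg_n = avg_(n-1) * (1 - 1/n) + x_n * (1/n)."""
    avg = 0.0
    for n, x in enumerate(samples, start=1):
        weight_cur = 1.0 / n           # weight of the newest sample
        weight_avg = 1.0 - weight_cur  # weight of the average so far
        avg = avg * weight_avg + x * weight_cur
    return avg


# Hypothetical values standing in for one weight observed at four snapshots.
samples = [0.9, 1.3, 0.7, 1.1]
result = running_average(samples)
mean = sum(samples) / len(samples)
```

With these weights the incremental update reproduces the ordinary arithmetic mean of all snapshots, without storing them.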

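The code below computes model_avg. It writes directly to the storage addresses behind the tensors returned by model_avg.state_dict(), so model_avg itself is modified without loading anything back.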

from typing import Dict

import torch
from torch import Tensor


def avg_state_dict(
        state_dict_cur: Dict[str, Tensor],
        state_dict_avg: Dict[str, Tensor],
        batch_idx_train: int,
        average_period: int,
        scaling_factor: float = 1.0,  # 1.0 leaves the average unscaled
) -> None:
    """Update state_dict_avg in place with a running average of state_dict_cur."""
    # The newest snapshot is one of (batch_idx_train / average_period) samples.
    weight_cur = average_period / batch_idx_train
    weight_avg = 1 - weight_cur

    # Tied parameters share one storage address; keep one name per address
    # so that shared weights are not updated twice.
    uniqued: Dict[int, str] = dict()
    for k, v in state_dict_avg.items():
        v_data_ptr = v.data_ptr()  # address of the tensor's data
        if v_data_ptr in uniqued:
            continue
        uniqued[v_data_ptr] = k
    uniqued_names = list(uniqued.values())

    for k in uniqued_names:
        v_avg = state_dict_avg[k]
        v_cur = state_dict_cur[k].to(device=v_avg.device)
        if v_avg.dtype == torch.int64:
            # Integer buffers (e.g. counters) are copied in place, not averaged.
            v_avg.copy_(v_cur)
        else:
            # In-place ops write through to the memory owned by model_avg.
            v_avg *= weight_avg
            v_avg += v_cur * weight_cur
            v_avg *= scaling_factor
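The in-place trick relies on the tensors returned by `state_dict()` sharing memory with the model's parameters. A minimal sketch of that behaviour, using a small `nn.Linear` as a stand-in model (the model and the factor 0.5 are illustrative choices, not from the original post):

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(2, 1)
state = model.state_dict()

# The state_dict tensor and the parameter point at the same memory.
same_memory = state["weight"].data_ptr() == model.weight.data_ptr()

before = model.weight.detach().clone()

# An in-place operation on the state_dict tensor changes the model.
state["weight"] *= 0.5
halved = torch.allclose(model.weight, before * 0.5)

# Rebinding the dict entry only changes the dict, not the model.
state["weight"] = torch.zeros_like(state["weight"])
still_halved = torch.allclose(model.weight, before * 0.5)
```

This is why operators such as `*=`, `+=`, and `copy_()` update the model, while assigning a new tensor to `state_dict_avg[k]` would leave the model untouched.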

The conventional approach: compute a new state dict, then load it back into the model

def avg_state_dict(
        state_dict_cur: Dict[str, Tensor],
        state_dict_avg: Dict[str, Tensor],
        batch_idx_train,
        average_period,
        scaling_factor=1.0
) -> Dict[str, Tensor]:
    ...
    return new_state_dict

model_avg.load_state_dict(
    avg_state_dict(state_dict_cur, state_dict_avg, batch_idx_train, average_period)
)
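For comparison, here is one way the conventional approach could be filled in. This is a sketch under assumptions, not the author's implementation: the helper name `avg_state_dict_copy`, its `weight_cur` parameter, and the two `nn.Linear` stand-in models are all hypothetical.

```python
from typing import Dict

import torch
from torch import Tensor, nn


def avg_state_dict_copy(
        state_dict_cur: Dict[str, Tensor],
        state_dict_avg: Dict[str, Tensor],
        weight_cur: float,
) -> Dict[str, Tensor]:
    """Return a new state dict; neither input is modified."""
    new_state_dict = {}
    for k, v_avg in state_dict_avg.items():
        v_cur = state_dict_cur[k].to(device=v_avg.device)
        if torch.is_floating_point(v_avg):
            new_state_dict[k] = v_avg * (1.0 - weight_cur) + v_cur * weight_cur
        else:
            # Integer buffers are taken from the current model as-is.
            new_state_dict[k] = v_cur.clone()
    return new_state_dict


torch.manual_seed(0)
model_cur = nn.Linear(3, 2)
model_avg = nn.Linear(3, 2)
expected = 0.75 * model_avg.weight.detach() + 0.25 * model_cur.weight.detach()

# The result has to be loaded back explicitly for model_avg to change.
model_avg.load_state_dict(
    avg_state_dict_copy(model_cur.state_dict(), model_avg.state_dict(), weight_cur=0.25)
)
```

Compared with the in-place version, this variant allocates a full copy of the weights on every update and needs the extra `load_state_dict` call, but it never mutates its inputs.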

posted @ 365/24/60