Coding Poineer

An Unconventional Way to Modify a Model (PyTorch)

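A model is usually saved once at the end of each epoch, but that makes the result somewhat random: the weights can fluctuate considerably right at the end of an epoch. Averaging the model several times within a single epoch makes the saved model more robust.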
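
As a rough illustration of what such an average looks like, here is a minimal sketch that uses plain floats in place of model weights. It assumes a snapshot is folded in every `average_period` batches with equal weight; `running_mean_weights` and the sample values are hypothetical names and numbers chosen for the example.

```python
def running_mean_weights(batch_idx_train: int, average_period: int):
    """Return (weight_avg, weight_cur) for an equal-weight running mean.

    Assumes one snapshot every `average_period` batches, so the snapshot taken
    at `batch_idx_train` is number k = batch_idx_train / average_period.
    """
    weight_cur = average_period / batch_idx_train  # 1 / k
    weight_avg = 1.0 - weight_cur                  # (k - 1) / k
    return weight_avg, weight_cur


# Scalar stand-ins for one model weight sampled every 100 batches.
snapshots = [1.0, 2.0, 6.0]
average_period = 100
avg = 0.0
for i, value in enumerate(snapshots, start=1):
    w_avg, w_cur = running_mean_weights(i * average_period, average_period)
    avg = avg * w_avg + value * w_cur

# The running value matches the plain mean of the snapshots (up to rounding).
print(avg, sum(snapshots) / len(snapshots))
```

With these weights every snapshot ends up contributing equally, so a single noisy batch near the end of the epoch has less influence on the saved model.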

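The function below computes model_avg. Instead of building a new state dict, it writes directly into the storage that the tensors returned by model_avg.state_dict() point to, and in this way modifies model_avg itself.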

from typing import Dict

import torch
from torch import Tensor


def avg_state_dict(
        state_dict_cur: Dict[str, Tensor],
        state_dict_avg: Dict[str, Tensor],
        batch_idx_train: int,
        average_period: int,
        scaling_factor: float = 6.6
) -> None:
    """Fold state_dict_cur into state_dict_avg in place."""
    # Running mean: the current model is one of batch_idx_train / average_period snapshots
    weight_cur = average_period / batch_idx_train
    weight_avg = 1 - weight_cur
    # Keep one name per storage address so tied (shared) tensors are updated only once
    uniqued: Dict[int, str] = dict()
    for k, v in state_dict_avg.items():
        v_data_ptr = v.data_ptr()  # address of the tensor's data
        if v_data_ptr in uniqued:
            continue
        uniqued[v_data_ptr] = k
    uniqued_names = list(uniqued.values())

    for k in uniqued_names:
        v_avg = state_dict_avg[k]
        v_cur = state_dict_cur[k].to(device=v_avg.device)
        if v_avg.dtype == torch.int64:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) are copied, not averaged
            v_avg.copy_(v_cur)
        else:
            v_avg *= weight_avg
            v_avg += v_cur * weight_cur
            v_avg *= scaling_factor
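
The in-place trick relies on state_dict() entries sharing memory with the model's parameters. The sketch below checks that on a toy nn.Linear, shows that rebinding a dict key does not reach the model, and shows why the data_ptr() de-duplication matters for tied weights. The `Tied` module and the variable names are hypothetical and exist only for this example.

```python
import torch
from torch import nn

model = nn.Linear(2, 2)
state = model.state_dict()  # entries are detached tensors that share the parameters' storage

# The entry and the parameter point at the same memory.
shares_storage = state["weight"].data_ptr() == model.weight.data_ptr()

# An in-place op on the entry is therefore visible through the model.
state["weight"].zero_()
print(model.weight.abs().sum().item())  # 0.0

# Rebinding a key only changes this dict; the model keeps its old values.
bias_before = model.bias.detach().clone()
state["bias"] = torch.ones(2)
bias_unchanged = torch.equal(model.bias.detach(), bias_before)


class Tied(nn.Module):
    """Hypothetical module whose two layers share one weight tensor."""

    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(2, 2, bias=False)
        self.b = nn.Linear(2, 2, bias=False)
        self.b.weight = self.a.weight  # weight tying


tied = Tied().state_dict()
# Two keys, one address: without de-duplication an in-place update would hit it twice.
same_address = tied["a.weight"].data_ptr() == tied["b.weight"].data_ptr()
print(shares_storage, bias_unchanged, same_address)  # True True True
```

In short, only in-place tensor operations (`*=`, `+=`, `copy_`, `zero_`) propagate to the model, while assigning a new tensor to a key affects only the dict.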

The conventional approach: compute a new state dict, then load it back into the model

def avg_state_dict(
        state_dict_cur: Dict[str, Tensor],
        state_dict_avg: Dict[str, Tensor],
        batch_idx_train,
        average_period,
        scaling_factor=6.6
) -> Dict[str, Tensor]:
    ...
    return new_state_dict

# Load the returned values back into model_avg
model_avg.load_state_dict(avg_state_dict(...))
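
For comparison, a self-contained sketch of this load-back pattern might look as follows. `averaged_copy`, `model_cur`, and the fill values are hypothetical and chosen only so the result is easy to check; it is one possible way to write the elided body, not the original implementation.

```python
from typing import Dict

import torch
from torch import Tensor, nn


def averaged_copy(
        state_dict_cur: Dict[str, Tensor],
        state_dict_avg: Dict[str, Tensor],
        weight_cur: float,
) -> Dict[str, Tensor]:
    """Hypothetical helper: return a new averaged state dict, leaving the inputs untouched."""
    new_state_dict: Dict[str, Tensor] = {}
    for k, v_avg in state_dict_avg.items():
        v_cur = state_dict_cur[k].to(device=v_avg.device)
        if torch.is_floating_point(v_avg):
            new_state_dict[k] = v_avg * (1.0 - weight_cur) + v_cur * weight_cur
        else:
            # Integer buffers are copied rather than averaged.
            new_state_dict[k] = v_cur.clone()
    return new_state_dict


model_cur = nn.Linear(2, 2)
model_avg = nn.Linear(2, 2)
with torch.no_grad():
    for p in model_cur.parameters():
        p.fill_(4.0)
    for p in model_avg.parameters():
        p.fill_(2.0)

new_state_dict = averaged_copy(
    model_cur.state_dict(), model_avg.state_dict(), weight_cur=0.5
)
model_avg.load_state_dict(new_state_dict)  # copies the new values into model_avg
print(model_avg.weight[0, 0].item())  # 3.0
```

This version allocates a second set of tensors and copies them into the model, which is the overhead the in-place variant avoids.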

posted @ 2022-10-12 17:57  365/24/60