BN

Some aspects of BN are worth special attention:

  1. The pros and cons of the train/test inconsistency.
  2. Pitfalls at inference time: the moving average.
  3. Pitfalls during training: batch size and batch distribution.
  4. Pitfalls when fine-tuning: parameterization, data distribution, and so on.
  5. Pitfalls in implementation: a multi-purpose BN implementation.
  6. Improvements such as GN, precise-BN, and so on.

BN behaves differently at training time and at test time.

During training, BN normalizes each batch with that batch's own mean and variance, and meanwhile updates its running statistics via an EMA (exponential moving average). At test time, no batch statistics are computed; instead, the EMA statistics accumulated during training are used. This EMA has a few failure modes (a small sketch of the update follows the list below):

  1. When λ is too small, the EMA is dominated by the most recent batches, so it is not a reasonable approximation of the population statistics.
  2. When λ is too large, the EMA needs many iterations to converge.
  3. When the model is unstable, or the data distribution is unstable, the EMA can run into trouble.
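
To make the role of λ concrete, here is a minimal sketch of the EMA update (my own illustration, not code from this post). I use the convention running ← λ·running + (1 - λ)·batch, so PyTorch's momentum argument corresponds to 1 - λ:

import torch

# Minimal sketch of BN's EMA update during training (illustration only).
# Convention: running <- lam * running + (1 - lam) * batch_stat,
# i.e. PyTorch's `momentum` argument corresponds to 1 - lam.
lam = 0.99
running_mean = torch.zeros(3)
for step in range(1000):
    batch = torch.randn(32, 3)        # stand-in data whose true mean is 0
    batch_mean = batch.mean(dim=0)
    running_mean = lam * running_mean + (1 - lam) * batch_mean

# With lam close to 1, the estimate converges slowly (many iterations needed);
# with lam small, running_mean mostly reflects the last few noisy batches.
print(running_mean)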

Using Precise-BatchNorm

Keep using an EMA, but with a relatively large λ, and freeze the model weights; then run forward passes for many iterations so the statistics settle.

I have not read the paper Rethinking "Batch" in BatchNorm very closely, but I did read the precise-BN code:

In case some of the functions used in it are unfamiliar, a few notes first.

itertools.islice() slices an iterator, and it consumes the iterator as it does so.
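
A quick illustration of that consuming behavior (my own toy example, not from the precise-BN code):

import itertools

it = iter(range(10))
first_three = list(itertools.islice(it, 3))  # takes items 0, 1, 2
print(first_three)  # [0, 1, 2]
print(next(it))     # 3 -- the sliced items are gone from `it`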

running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)

This is easy to understand: it is an incremental update of the running average, equivalent to summing all the per-batch statistics first and then dividing by the count.
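
A small sanity check of this equivalence (my own sketch, names are illustrative):

import torch

# Check that the incremental update m <- m + (x_k - m) / k
# reproduces the plain mean of x_1 .. x_n.
xs = [torch.randn(4) for _ in range(10)]
m = torch.zeros(4)
for ind, x in enumerate(xs):
    m += (x - m) / (ind + 1)
print(torch.allclose(m, torch.stack(xs).mean(dim=0)))  # True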

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import itertools

import torch

BN_MODULE_TYPES = (
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
)


@torch.no_grad()
def update_bn_stats(model, data_loader, num_iters: int = 200):
    """
    Recompute and update the batch norm stats to make them more precise. During
    training both BN stats and the weight are changing after every iteration, so
    the running average can not precisely reflect the actual stats of the
    current model.
    In this function, the BN stats are recomputed with fixed weights, to make
    the running average more precise. Specifically, it computes the true average
    of per-batch mean/variance instead of the running average.

    Args:
        model (nn.Module): the model whose bn stats will be recomputed.

            Note that:

            1. This function will not alter the training mode of the given model.
               Users are responsible for setting the layers that need
               precise-BN to training mode, prior to calling this function.

            2. Be careful if your models contain other stateful layers in
               addition to BN, i.e. layers whose state can change in forward
               iterations. This function will alter their state. If you wish
               to keep them unchanged, you need to either pass in a submodule
               without those layers, or back up the states.
        data_loader (iterator): an iterator that produces data used as inputs
            to the model.
        num_iters (int): number of iterations to compute the stats.
    """
    bn_layers = get_bn_modules(model)

    if len(bn_layers) == 0:
        return

    # In order to make the running stats only reflect the current batch, the
    # momentum is disabled.
    # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
    # Setting the momentum to 1.0 to compute the stats without momentum.
    momentum_actual = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.momentum = 1.0

    # Note that running_var actually means "running average of variance"
    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]

    for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
        model(inputs)

        for i, bn in enumerate(bn_layers):
            # Accumulates the bn stats.
            running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
            running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
            # We compute the "average of variance" across iterations.
    assert ind == num_iters - 1, (
        "update_bn_stats is meant to run for {} iterations, "
        "but the dataloader stops at {} iterations.".format(num_iters, ind)
    )

    for i, bn in enumerate(bn_layers):
        # Sets the precise bn stats.
        bn.running_mean = running_mean[i]
        bn.running_var = running_var[i]
        bn.momentum = momentum_actual[i]


def get_bn_modules(model):
    """
    Find all BatchNorm (BN) modules that are in training mode. See
    cvpack2.modeling.nn_utils.precise_bn.BN_MODULE_TYPES for a list of all modules that are
    included in this search.

    Args:
        model (nn.Module): a model possibly containing BN modules.

    Returns:
        list[nn.Module]: all BN modules in the model.
    """
    # Finds all the bn layers.
    bn_layers = [
        m
        for m in model.modules()
        if m.training and isinstance(m, BN_MODULE_TYPES)
    ]
    return bn_layers
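
A hypothetical usage sketch (the model and data below are stand-ins of my own, not from the post):

import torch

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
)
loader = [torch.randn(4, 3, 32, 32) for _ in range(200)]  # stand-in for a DataLoader

model.train()  # BN layers must be in training mode for get_bn_modules to find them
update_bn_stats(model, loader, num_iters=200)
model.eval()   # inference now uses the precise stats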
