Deep Learning (Normalization)

In deep learning, the common normalization operations are BN, LN, GN, and IN. The table below summarizes the main differences between them:

| Method | Statistics computed over | Kept separate | Typical use | Characteristics |
|--------|--------------------------|---------------|-------------|-----------------|
| BatchNorm | (N, H, W), independently for each channel | Channel | CNNs | Depends on batch size; behaves differently in training and inference |
| LayerNorm | (C, H, W), independently for each sample | Batch | NLP, Transformers | Normalizes all features of each sample |
| GroupNorm | (H, W), for each group of channels of each sample | Batch & Group | Small or varying batch sizes | Independent of batch size; normalizes channels in groups |
| InstanceNorm | (H, W), for each channel of each sample | Batch & Channel | Style transfer, generative models | Normalizes each channel of each sample |
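All four methods apply the same element-wise transform and differ only in which elements share a mean and variance. With the mean $\mu$ and the (biased) variance $\sigma^2$ computed over the dimensions listed above, each element is normalized and then rescaled by learnable parameters $\gamma$ and $\beta$:

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta$$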

Below are from-scratch implementations of each method, with the results checked against PyTorch's built-in layers:

import torch
import torch.nn as nn

class BatchNormalizationModel(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(BatchNormalizationModel, self).__init__()
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.eps = eps

    def forward(self, x):
        # per-channel statistics over the batch and spatial dims (N, H, W)
        mean = x.mean(dim=(0, 2, 3), keepdim=True)
        var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma.view(1, -1, 1, 1) * x_hat + self.beta.view(1, -1, 1, 1)
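# Note: this matches nn.BatchNorm2d only in training mode, where normalization
# uses the current batch's statistics. In eval mode nn.BatchNorm2d switches to
# the running mean/variance accumulated during training, which this sketch does
# not track; the comparison below therefore relies on the PyTorch module being
# in its default training mode.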

class LayerNormalizationModel(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(LayerNormalizationModel, self).__init__()
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features)) 
        self.eps = eps

    def forward(self, x):
        # per-sample statistics over the last (feature) dimension only
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta 
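# Note: with normalized_shape=1000, nn.LayerNorm normalizes only the last
# dimension, which is what this sketch reproduces. To normalize a full
# (C, H, W) sample as described in the table above, nn.LayerNorm would be
# given normalized_shape=(C, H, W) and the mean/var dims adjusted to match.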

class GroupNormalizationModel(nn.Module):
    def __init__(self, num_features, num_groups, eps=1e-5):
        super(GroupNormalizationModel, self).__init__()
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.num_groups = num_groups
        self.eps = eps

    def forward(self, x):
        N, C, H, W = x.size()
        # group the channels: (N, G, C//G * H * W), then normalize within each group
        x = x.view(N, self.num_groups, -1)
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        x_hat = x_hat.view(N, C, H, W)
        return self.gamma.view(1, -1, 1, 1) * x_hat + self.beta.view(1, -1, 1, 1)
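# Note: num_features must be divisible by num_groups. Two limiting cases relate
# GN to the other methods: num_groups=1 normalizes all channels of a sample
# together (LayerNorm over (C, H, W)), while num_groups=num_features normalizes
# each channel separately (InstanceNorm).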


class InstanceNormalizationModel(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(InstanceNormalizationModel, self).__init__()
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.eps = eps

    def forward(self, x):
        N, C, H, W = x.size()
        # treat each (sample, channel) pair as one instance; statistics over H*W
        x = x.view(N * C, -1)
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, unbiased=False, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        x_hat = x_hat.view(N, C, H, W)
        return self.gamma.view(1, -1, 1, 1) * x_hat + self.beta.view(1, -1, 1, 1)
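# Note: nn.InstanceNorm2d defaults to affine=False (no gamma/beta). The
# comparison below still passes because gamma and beta are initialized to 1
# and 0 here; with affine=True, or after training, the two models would need
# matching parameters.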


class NormalizationModel(nn.Module):
    def __init__(self):
        super(NormalizationModel, self).__init__()
        self.bn = nn.BatchNorm2d(num_features=10)
        self.ln = nn.LayerNorm(normalized_shape=1000)
        self.gn = nn.GroupNorm(num_groups=2, num_channels=10)
        self.in_norm = nn.InstanceNorm2d(num_features=10)

    def forward(self, x_bn, x_ln, x_gn, x_in):
        out_bn = self.bn(x_bn)
        out_ln = self.ln(x_ln)
        out_gn = self.gn(x_gn)
        out_in = self.in_norm(x_in)
        return out_bn, out_ln, out_gn, out_in


x_bn = torch.randn(20, 10, 50, 50)  # BatchNorm input
x_ln = torch.randn(20, 10, 1000)    # LayerNorm input
x_gn = torch.randn(20, 10, 50, 50)  # GroupNorm input
x_in = torch.randn(20, 10, 50, 50)  # InstanceNorm input

bn_model = BatchNormalizationModel(num_features=10)
ln_model = LayerNormalizationModel(num_features=1000)
gn_model = GroupNormalizationModel(num_features=10, num_groups=2)
in_model = InstanceNormalizationModel(num_features=10)
model = NormalizationModel()

bn_output = bn_model(x_bn)
ln_output = ln_model(x_ln)
gn_output = gn_model(x_gn)
in_output = in_model(x_in)

out_bn, out_ln, out_gn, out_in = model(x_bn, x_ln, x_gn, x_in)

print(torch.allclose(bn_output, out_bn, atol=1e-6))
print(torch.allclose(ln_output, out_ln, atol=1e-6))
print(torch.allclose(gn_output, out_gn, atol=1e-6))
print(torch.allclose(in_output, out_in, atol=1e-6))
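If the manual implementations are correct, all four comparisons should print True. As one more sanity check of the relationships in the table, GroupNorm with one group per channel should reproduce InstanceNorm; a minimal sketch, reusing x_in from above:

gn_as_in = nn.GroupNorm(num_groups=10, num_channels=10)  # one group per channel
print(torch.allclose(gn_as_in(x_in), nn.InstanceNorm2d(num_features=10)(x_in),
                     atol=1e-6))  # should print True if the equivalence holds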