在模型的构造参数中加入 norm_layer,实现可选择的归一化(norm)方式

以下面的代码为例。
来源:pytorch-CycleGAN-and-pix2pix 项目 models/networks.py 中的 NLayerDiscriminator

from torchvision import models
import torch.nn as nn
import torch
import functools

class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator.

    The normalization layer is injected through ``norm_layer`` so the same
    architecture can be built with BatchNorm, InstanceNorm, or (with
    ``norm_layer=None``) no normalization at all.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator.

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer;
                               later layers use ndf * 2**k filters, capped at ndf * 8
            n_layers (int)  -- the number of stride-2 (downsampling) conv layers;
                               the network contains n_layers + 2 conv layers in total
            norm_layer      -- normalization layer factory, called as
                               norm_layer(num_channels); e.g. nn.BatchNorm2d,
                               nn.InstanceNorm2d, or a functools.partial of one of
                               them. Pass None to build the network without
                               normalization.
        """
        super().__init__()
        if norm_layer is None:
            # Without normalization the convs must keep their own bias.
            # nn.Identity accepts (and ignores) the channel-count argument.
            norm_layer = nn.Identity
            use_bias = True
        else:
            # Look through functools.partial(nn.InstanceNorm2d, ...) to the wrapped class.
            norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
            # No need to use bias as BatchNorm2d has affine parameters; InstanceNorm2d
            # defaults to affine=False, so the preceding conv keeps its bias.
            use_bias = norm_cls == nn.InstanceNorm2d

        kw = 4
        padw = 1
        # First layer is never normalized and always keeps the conv's default bias.
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One more conv/norm/activation block, this time with stride 1.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward: run ``input`` through the layer stack and return the 1-channel prediction map."""
        return self.model(input)
# Network structure: the layer list built for input_nc=3, ndf=64, n_layers=3 with
# norm_layer=nn.InstanceNorm2d (default affine=False, track_running_stats=False).
# The convs show no bias=False, i.e. they keep their bias because use_bias is True
# for InstanceNorm2d; the final conv outputs a 1-channel prediction map.
"""[Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)), LeakyReLU(negative_slope=0.2, inplace=True), 
Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)), InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False), LeakyReLU(negative_slope=0.2, inplace=True),
 Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)), InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False), LeakyReLU(negative_slope=0.2, inplace=True),
 Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1)), InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False), LeakyReLU(negative_slope=0.2, inplace=True),
 Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))]"""

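在模型的初始化参数中加入 norm_layer=nn.BatchNorm2d(默认使用 BatchNorm),调用时可传入 nn.InstanceNorm2d 或 functools.partial(nn.InstanceNorm2d, ...) 来切换归一化方式。
随后用 use_bias 判断卷积层是否需要 bias:BatchNorm2d 自带可学习的仿射参数(affine),卷积不需要 bias;InstanceNorm2d 默认 affine=False,卷积需要 bias。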

posted @ 2023-01-14 18:42  梅雨明夏  阅读(277)  评论(0)  编辑  收藏  举报