pytorch 不使用转置卷积来实现上采样
上采样(upsampling)一般包括2种方式:
- Resize,如双线性插值直接缩放,类似于图像缩放,概念可见最邻近插值算法和双线性插值算法——图像缩放
- Deconvolution,也叫Transposed Convolution,可见逆卷积的详细解释ConvTranspose2d(fractionally-strided convolutions)
第二种方法如何用pytorch实现可见上面的链接
这里想要介绍的是如何使用pytorch实现第一种方法:
- 有两个模块都支持该上采样的实现,一个是torch.nn模块,详情可见:pytorch torch.nn 实现上采样——nn.Upsample (但是现在这种方法已经不推荐使用了,最好使用下面的方法)
- 另一个是torch.nn.functional模块,详情可见:pytorch torch.nn.functional实现插值和上采样
举例:
1)使用torch.nn模块实现一个生成器为:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvLayer(nn.Module):
    """Reflection-pad the input, then apply a 2-D convolution.

    The padding is ``kernel_size // 2`` on every side, so with an odd
    kernel and ``stride=1`` the spatial size is preserved; with
    ``stride=2`` it is halved (rounded up).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Side length of the square convolution kernel.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(padding)
        # The convolution itself is unpadded; reflection padding is applied
        # explicitly in forward().
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Return ``conv(reflection_pad(x))``."""
        out = self.reflection_pad(x)
        out = self.conv(out)
        return out


class Generator(nn.Module):
    """Encoder/decoder generator that upsamples with ``nn.Upsample``.

    The encoder applies three stride-2 ``ConvLayer`` stages (each halves the
    spatial size); the decoder applies three x2 bilinear upsampling stages,
    each followed by a 1x1 convolution. The output has 3 channels passed
    through ``tanh``. For inputs whose height and width are multiples of 8
    the output has the same spatial size as the input.

    Args:
        in_channels: Number of channels of the input image tensor.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.encoder = nn.Sequential(
            ConvLayer(self.in_channels, 32, 3, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            ConvLayer(32, 64, 3, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            ConvLayer(64, 128, 3, 2),
        )

        # nn.Upsample holds no parameters, so a single instance can safely be
        # placed at several positions of the Sequential.
        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.decoder = nn.Sequential(
            upsample,
            nn.Conv2d(128, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            upsample,
            nn.Conv2d(64, 32, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            upsample,
            nn.Conv2d(32, 3, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` and decode it back to a 3-channel tensor."""
        x = self.encoder(x)
        out = self.decoder(x)
        return out


def test():
    """Print the generator's sub-modules and the shape/type of one output."""
    net = Generator(3)
    for module in net.children():
        print(module)
    # A plain tensor is enough: the original wrapped it in ``Variable``
    # without importing either ``torch`` or ``Variable``.
    x = torch.randn(2, 3, 224, 224)
    output = net(x)
    print('output :', output.size())
    print(type(output))


if __name__ == '__main__':
    test()
返回:
$ python model.py
Sequential(
  (0): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
  )
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
  )
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
  (6): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
  )
)
Sequential(
  (0): Upsample(scale_factor=2, mode=bilinear)
  (1): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
  (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ReLU()
  (4): Upsample(scale_factor=2, mode=bilinear)
  (5): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): ReLU()
  (8): Upsample(scale_factor=2, mode=bilinear)
  (9): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
  (10): Tanh()
)
output : torch.Size([2, 3, 224, 224])
<class 'torch.Tensor'>
但是这个会有警告(该警告出现在较早的 PyTorch 版本中,例如 0.4.1;较新的版本里 nn.Upsample 应该已不再给出此警告,请以所用版本的官方文档为准):
UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.
2)可使用torch.nn.functional模块的interpolate函数将上面的生成器替换为:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvLayer(nn.Module):
    """Reflection-pad the input, then apply a 2-D convolution.

    The padding is ``kernel_size // 2`` on every side, so with an odd
    kernel and ``stride=1`` the spatial size is preserved; with
    ``stride=2`` it is halved (rounded up).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Side length of the square convolution kernel.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(padding)
        # The convolution itself is unpadded; reflection padding is applied
        # explicitly in forward().
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Return ``conv(reflection_pad(x))``."""
        out = self.reflection_pad(x)
        out = self.conv(out)
        return out


class Generator(nn.Module):
    """Encoder/decoder generator that upsamples with ``F.interpolate``.

    The encoder applies three stride-2 ``ConvLayer`` stages (each halves the
    spatial size). The decoder is split into three 1x1-convolution stages
    (``decoder1``..``decoder3``); ``forward`` performs a x2 bilinear
    interpolation before each of them. The output has 3 channels passed
    through ``tanh``. For inputs whose height and width are multiples of 8
    the output has the same spatial size as the input.

    Args:
        in_channels: Number of channels of the input image tensor.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.encoder = nn.Sequential(
            ConvLayer(self.in_channels, 32, 3, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            ConvLayer(32, 64, 3, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            ConvLayer(64, 128, 3, 2),
        )

        # The functional interpolate call cannot live inside nn.Sequential,
        # so the decoder is split at the points where upsampling happens.
        self.decoder1 = nn.Sequential(
            nn.Conv2d(128, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.decoder2 = nn.Sequential(
            nn.Conv2d(64, 32, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )
        self.decoder3 = nn.Sequential(
            nn.Conv2d(32, 3, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, then alternate x2 upsampling with the decoder stages."""
        x = self.encoder(x)
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x = self.decoder1(x)
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x = self.decoder2(x)
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        out = self.decoder3(x)
        return out


def test():
    """Print the generator's sub-modules and the shape/type of one output."""
    net = Generator(3)
    for module in net.children():
        print(module)
    # A plain tensor is enough: the original wrapped it in ``Variable``
    # without importing either ``torch`` or ``Variable``.
    x = torch.randn(2, 3, 224, 224)
    output = net(x)
    print('output :', output.size())
    print(type(output))


if __name__ == '__main__':
    test()
返回:
$ python model.py
Sequential(
  (0): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
  )
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
  )
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
  (6): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
  )
)
Sequential(
  (0): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)
Sequential(
  (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)
Sequential(
  (0): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
  (1): Tanh()
)
output : torch.Size([2, 3, 224, 224])
<class 'torch.Tensor'>