Common padding functions in PyTorch
1) ReflectionPad2d
CLASS torch.nn.ReflectionPad2d(padding)
Pads the input tensor using the reflection of the input boundary.
For N-dimensional padding, use torch.nn.functional.pad().
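As a minimal sketch, assuming a reasonably recent PyTorch version, nn.ReflectionPad2d(2) should produce the same tensor as torch.nn.functional.pad called with mode='reflect' and a 4-element pad tuple:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)

# Module form: pads 2 on every side
module_out = nn.ReflectionPad2d(2)(x)

# Functional form: the pad tuple lists (left, right, top, bottom) for the last two dims
functional_out = F.pad(x, (2, 2, 2, 2), mode='reflect')

print(torch.equal(module_out, functional_out))  # True
```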
Parameters:
- padding (int, tuple): the size of the padding. If it is an int a, the same padding is used on every side, equivalent to passing (a, a, a, a). If it is a 4-tuple, it is interpreted as (padding_left, padding_right, padding_top, padding_bottom).
Shape:
- Input: (N, C, H_in, W_in)
- Output: (N, C, H_out, W_out)
where:
- H_out = H_in + padding_top + padding_bottom
- W_out = W_in + padding_left + padding_right
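The tuple order and the shape formulas can be checked with a small sketch using an asymmetric padding (the values in pad are arbitrary, chosen only for illustration and kept below the input size of 3):

```python
import torch
import torch.nn as nn

x = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)

# Tuple order: (padding_left, padding_right, padding_top, padding_bottom)
pad = (1, 2, 0, 1)
out = nn.ReflectionPad2d(pad)(x)

h_in, w_in = x.shape[-2:]
h_out = h_in + pad[2] + pad[3]  # H_out = H_in + top + bottom
w_out = w_in + pad[0] + pad[1]  # W_out = W_in + left + right
print(out.shape)  # torch.Size([1, 1, 4, 6])
```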
Example:
(deeplearning) userdeMacBook-Pro:pytorch-CycleGAN-and-pix2pix user$ python
Python 3.6.8 |Anaconda, Inc.| (default, Dec 29 2018, 19:04:46)
[GCC 4.2.1 Compatible Clang 4.0.1 (tags/RELEASE_401/final)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> from torch import nn
>>> import torch
>>> m = nn.ReflectionPad2d(2)
>>> input = torch.arange(9, dtype=torch.float).reshape(1,1,3,3)
>>> m(input)
tensor([[[[8., 7., 6., 7., 8., 7., 6.],
          [5., 4., 3., 4., 5., 4., 3.],
          [2., 1., 0., 1., 2., 1., 0.],
          [5., 4., 3., 4., 5., 4., 3.],
          [8., 7., 6., 7., 8., 7., 6.],
          [5., 4., 3., 4., 5., 4., 3.],
          [2., 1., 0., 1., 2., 1., 0.]]]])
>>> m = nn.ReflectionPad2d(1)
>>> m(input)
tensor([[[[4., 3., 4., 5., 4.],
          [1., 0., 1., 2., 1.],
          [4., 3., 4., 5., 4.],
          [7., 6., 7., 8., 7.],
          [4., 3., 4., 5., 4.]]]])
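As the example shows, the padded values are the values next to each border, mirrored across the border row or column (the border element itself is not repeated). This is reflection padding.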
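To make the mirroring rule concrete, the sketch below rebuilds the result with plain index arithmetic: a padded index i that falls before the start maps to -i, and one that falls past the end of a dimension of size n maps to 2*(n-1)-i. The helper names reflect_index and manual_reflection_pad2d are hypothetical, written only for this illustration, and the sketch assumes the padding is smaller than the input size so a single reflection is enough.

```python
import torch
import torch.nn as nn

def reflect_index(i, n):
    # Mirror an out-of-range index across the border without repeating the edge element
    if i < 0:
        return -i
    if i >= n:
        return 2 * (n - 1) - i
    return i

def manual_reflection_pad2d(x, p):
    # x: (N, C, H, W); p: same padding on all four sides, assumes p < H and p < W
    h, w = x.shape[-2:]
    rows = torch.tensor([reflect_index(i, h) for i in range(-p, h + p)])
    cols = torch.tensor([reflect_index(j, w) for j in range(-p, w + p)])
    # Gather the rows first, then the columns
    return x[..., rows, :][..., cols]

x = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
print(torch.equal(manual_reflection_pad2d(x, 2), nn.ReflectionPad2d(2)(x)))  # True
```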
⚠️ Each padding size must be smaller than the corresponding input dimension; otherwise an error is raised:
>>> m = nn.ReflectionPad2d(3)
>>> m(input)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/nn/modules/padding.py", line 172, in forward
    return F.pad(input, self.padding, 'reflect')
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/nn/functional.py", line 2685, in pad
    ret = torch._C._nn.reflection_pad2d(input, pad)
RuntimeError: Argument #4: Padding size should be less than the corresponding input dimension, but got: padding (3, 3) at dimension 3 of input [1, 1, 3, 3]
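One way to fail early with a clearer message is to validate the padding before calling the module. This is a minimal sketch; check_reflection_padding is a hypothetical helper, and it raises its own ValueError so it does not depend on which exception type a given PyTorch version uses.

```python
import torch
import torch.nn as nn

def check_reflection_padding(x, padding):
    # padding: (left, right, top, bottom); each must be < the matching input dimension
    left, right, top, bottom = padding
    h, w = x.shape[-2:]
    if max(left, right) >= w or max(top, bottom) >= h:
        raise ValueError(f"reflection padding {padding} too large for input of size {h}x{w}")

x = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)

check_reflection_padding(x, (2, 2, 2, 2))  # OK: 2 < 3
print(nn.ReflectionPad2d(2)(x).shape)      # torch.Size([1, 1, 7, 7])

try:
    check_reflection_padding(x, (3, 3, 3, 3))  # 3 is not < 3
except ValueError as e:
    print("rejected:", e)
```

For an input of size H x W, the largest accepted reflection padding is therefore H - 1 vertically and W - 1 horizontally.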
2) ReplicationPad2d
CLASS torch.nn.ReplicationPad2d(padding)
Pads the input tensor by replicating the values on the input boundary.
For N-dimensional padding, use torch.nn.functional.pad().
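As a sketch of the functional route for another dimensionality, assuming a reasonably recent PyTorch version: for a 3-D input of shape (N, C, W), torch.nn.functional.pad with mode='replicate' takes a 2-element tuple (left, right) for the last dimension, and the result should match nn.ReplicationPad1d.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.arange(4, dtype=torch.float).reshape(1, 1, 4)  # (N, C, W)

# Only the last dimension is padded: 1 on the left, 3 on the right
out = F.pad(x, (1, 3), mode='replicate')
print(out)  # tensor([[[0., 0., 1., 2., 3., 3., 3., 3.]]])

print(torch.equal(out, nn.ReplicationPad1d((1, 3))(x)))  # True
```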
Parameters:
- padding (int, tuple): the size of the padding. If it is an int a, the same padding is used on every side, equivalent to passing (a, a, a, a). If it is a 4-tuple, it is interpreted as (padding_left, padding_right, padding_top, padding_bottom).
Shape:
- Input: (N, C, H_in, W_in)
- Output: (N, C, H_out, W_out)
where:
- H_out = H_in + padding_top + padding_bottom
- W_out = W_in + padding_left + padding_right
Example (reusing the 3x3 input from above):
>>> m = nn.ReplicationPad2d(2)
>>> m(input)
tensor([[[[0., 0., 0., 1., 2., 2., 2.],
          [0., 0., 0., 1., 2., 2., 2.],
          [0., 0., 0., 1., 2., 2., 2.],
          [3., 3., 3., 4., 5., 5., 5.],
          [6., 6., 6., 7., 8., 8., 8.],
          [6., 6., 6., 7., 8., 8., 8.],
          [6., 6., 6., 7., 8., 8., 8.]]]])
As shown, each padded position takes a copy of the nearest boundary value.
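The "copy the nearest boundary value" rule amounts to clamping every out-of-range index to the valid range. A minimal sketch, where manual_replication_pad2d is a hypothetical helper used only to illustrate the rule:

```python
import torch
import torch.nn as nn

def manual_replication_pad2d(x, p):
    # Clamp each out-of-range index to the nearest valid one, i.e. copy the edge value
    h, w = x.shape[-2:]
    rows = torch.arange(-p, h + p).clamp(0, h - 1)
    cols = torch.arange(-p, w + p).clamp(0, w - 1)
    # Gather the rows first, then the columns
    return x[..., rows, :][..., cols]

x = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
print(torch.equal(manual_replication_pad2d(x, 2), nn.ReplicationPad2d(2)(x)))  # True
```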
3) ZeroPad2d
CLASS torch.nn.ZeroPad2d(padding)
Pads the boundaries of the input tensor with zeros.
For N-dimensional padding, use torch.nn.functional.pad().
Parameters:
- padding (int, tuple): the size of the padding. If it is an int a, the same padding is used on every side, equivalent to passing (a, a, a, a). If it is a 4-tuple, it is interpreted as (padding_left, padding_right, padding_top, padding_bottom).
Shape:
- Input: (N, C, H_in, W_in)
- Output: (N, C, H_out, W_out)
where:
- H_out = H_in + padding_top + padding_bottom
- W_out = W_in + padding_left + padding_right
Example:
>>> m = nn.ZeroPad2d(2)
>>> m(input)
tensor([[[[0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 1., 2., 0., 0.],
          [0., 0., 3., 4., 5., 0., 0.],
          [0., 0., 6., 7., 8., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0.]]]])
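Zero padding is the special case of constant padding with value 0. The sketch below checks that ZeroPad2d, ConstantPad2d with 0.0, and torch.nn.functional.pad with its default constant mode all agree:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)

zero = nn.ZeroPad2d(2)(x)
const = nn.ConstantPad2d(2, 0.0)(x)
# F.pad defaults to mode='constant' with value=0
functional = F.pad(x, (2, 2, 2, 2))

print(torch.equal(zero, const) and torch.equal(zero, functional))  # True
```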
4) ConstantPad2d
CLASS torch.nn.ConstantPad2d(padding, value)
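Pads the boundaries of the input tensor with a constant value.
For N-dimensional padding, use torch.nn.functional.pad().
Parameters:
- padding (int, tuple): the size of the padding. If it is an int a, the same padding is used on every side, equivalent to passing (a, a, a, a). If it is a 4-tuple, it is interpreted as (padding_left, padding_right, padding_top, padding_bottom).
- value: the constant used to fill the padded area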
Shape:
- Input: (N, C, H_in, W_in)
- Output: (N, C, H_out, W_out)
where:
- H_out = H_in + padding_top + padding_bottom
- W_out = W_in + padding_left + padding_right
Example:
>>> m = nn.ConstantPad2d(2, 3.99)
>>> m(input)
tensor([[[[3.9900, 3.9900, 3.9900, 3.9900, 3.9900, 3.9900, 3.9900],
          [3.9900, 3.9900, 3.9900, 3.9900, 3.9900, 3.9900, 3.9900],
          [3.9900, 3.9900, 0.0000, 1.0000, 2.0000, 3.9900, 3.9900],
          [3.9900, 3.9900, 3.0000, 4.0000, 5.0000, 3.9900, 3.9900],
          [3.9900, 3.9900, 6.0000, 7.0000, 8.0000, 3.9900, 3.9900],
          [3.9900, 3.9900, 3.9900, 3.9900, 3.9900, 3.9900, 3.9900],
          [3.9900, 3.9900, 3.9900, 3.9900, 3.9900, 3.9900, 3.9900]]]])
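Constant padding can also be pictured as allocating a tensor filled with value at the output size from the shape formulas and copying the input into its interior. A minimal sketch with an asymmetric tuple; manual_constant_pad2d is a hypothetical helper and assumes non-negative padding:

```python
import torch
import torch.nn as nn

def manual_constant_pad2d(x, padding, value):
    # padding: (left, right, top, bottom), assumed non-negative
    left, right, top, bottom = padding
    n, c, h, w = x.shape
    out = torch.full((n, c, h + top + bottom, w + left + right), value,
                     dtype=x.dtype, device=x.device)
    # Copy the original tensor into the interior; everything else keeps `value`
    out[..., top:top + h, left:left + w] = x
    return out

x = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
pad = (1, 0, 2, 1)  # left=1, right=0, top=2, bottom=1
manual = manual_constant_pad2d(x, pad, 3.99)
print(manual.shape)  # torch.Size([1, 1, 6, 4])
print(torch.equal(manual, nn.ConstantPad2d(pad, 3.99)(x)))  # True
```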