卷积 - 用pytorch计算

对于一张图片,每个卷积核的通道数都和图片通道数一样,用n个卷积核进行卷积得到的结果就是一个n通道的特征图。

卷积的步长(stride)和padding决定了产生的特征图的size

对于一张7*7*3的图片(长7,宽7,3通道),使用两个3*3*3的卷积核(长3,宽3,3通道)进行卷积,如下图:

image

最左边一列是原始图片的3个通道的数据,中间两列红色的是2个卷积核。最右边一列是卷积得到的feature map,2通道。
卷积核的3通道分别和图像的3个通道进行element-wise的相乘,结果再相加。
比如,上图绿色框中的3的计算过程就是每行的红框和对应的蓝框做卷积得到3个数,分别是1,1,0. 然后加上下面的Bias,结果就是3.

pytorch算上图卷积

import torch
import torch.nn as nn

# 输入是4维(1 * 3 * 5 * 5)
# 1张图,3通道,高5,宽5
x = [[
    [
        [0,1,1,0,2],
        [0,2,2,1,1],
        [0,2,1,1,2],
        [1,0,2,1,2],
        [0,1,1,1,1]
    ],
    [
        [2,2,1,0,0],
        [2,1,0,0,1],
        [2,2,1,2,2],
        [1,2,0,1,2],
        [1,2,0,2,1]
    ],
    [
        [0,0,1,0,2],
        [1,1,1,1,1],
        [2,0,2,1,1],
        [1,1,1,0,0],
        [0,0,0,1,0]
    ]
]]

w0  =  [[[-1,  0,  1],
         [-1,  0,  1],
         [ 0,  1,  0]],

        [[ 1, -1,  1],
         [ 1,  1,  0],
         [-1,  0, -1]],

        [[ 0,  1,  0],
         [-1,  0,  1],
         [-1,  0,  0]]]

w1  =  [[[-1,  1,  0],
         [ 1,  0, -1],
         [-1,  0,  1]],

        [[ 0,  0,  1],
         [ 0,  1,  1],
         [ 1,  1, -1]],

        [[ 0,  1,  0],
         [-1, -1,  1],
         [ 1, -1, -1]]]
bias = [1,0]
# 转换成float类型的Tensor,后面会自动算梯度,要用float。
x = torch.tensor(x).float()
w = nn.Conv2d(3,2,3,padding=1,stride=2) # 创建卷积函数
w.weight.data = torch.tensor([w0,w1]).float() # 使用自定义卷积核,不然pytorch默认会自己生成
w.bias.data = torch.tensor(bias).float()
output = w(x)
'''
不想转成float的话,就关闭自动计算梯度,也省显存
x = torch.tensor(x)
w = nn.Conv2d(3,2,3,padding=1,stride=2)
w.weight.data = torch.tensor([w0,w1])
w.bias.data = torch.tensor(bias)
with torch.no_grad():
    output = w(x)
'''
print(output)

输出如下:

tensor([[[[ 3.,  3.,  1.],
          [ 6.,  3.,  3.],
          [ 5.,  9.,  0.]],

         [[ 4.,  0., -2.],
          [-1.,  6.,  4.],
          [ 6.,  7.,  2.]]]], grad_fn=<ThnnConv2DBackward>)

关闭梯度计算后的输出:

tensor([[[[ 3,  3,  1],
          [ 6,  3,  3],
          [ 5,  9,  0]],

         [[ 4,  0, -2],
          [-1,  6,  4],
          [ 6,  7,  2]]]])
posted @ 2021-10-10 12:17  王冰冰  阅读(395)  评论(0编辑  收藏  举报