Dive into Deep Learning: Convolution Operations

Convolution

Convolution concepts

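  • Convolution originally comes from signal processing. CNNs adopt it as the basic operation for extracting features from the input. (Strictly speaking, deep-learning frameworks compute cross-correlation: the kernel is not flipped.)
  • Padding: fill zeros around the outside of the input, on every side, so that the convolution output has a desired size.
  • Stride: the distance the kernel moves at each step as it slides over the input.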
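Padding and stride together fix the output size. Along each spatial dimension, an input of length n, a kernel of length k, p zeros on each side and stride s give floor((n + 2p - k) / s) + 1 outputs. This assumes no dilation, which matches the shape formula in PyTorch's Conv2d documentation when dilation=1. A minimal sketch; the helper name conv_output_size is hypothetical:

```python
def conv_output_size(n: int, k: int, pad: int, stride: int) -> int:
    """Output length along one spatial dimension (assumes dilation = 1).

    n: input length, k: kernel length, pad: zeros added on EACH side,
    stride: step between neighbouring windows.
    """
    # the padded input has n + 2*pad positions; the first window starts at 0
    # and each further window starts `stride` positions later
    return (n + 2 * pad - k) // stride + 1

# the layer example later in this post: 5x5 input, 3x3 kernel, pad=2, stride=2
print(conv_output_size(5, 3, 2, 2))  # 4
# the unfold example: 4x4 input, 3x3 window, no padding, stride 1
print(conv_output_size(4, 3, 0, 1))  # 2
```

The first call explains why the 5×5 input in the layer example below produces a 4×4 output per channel.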

Implementing convolution with tensors

import torch
a = torch.arange(16).view(4,4)
a
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
b = a.unfold(0,3,1) # along dimension 0, take windows of 3 elements with step 1
b
tensor([[[ 0,  4,  8],
         [ 1,  5,  9],
         [ 2,  6, 10],
         [ 3,  7, 11]],
        [[ 4,  8, 12],
         [ 5,  9, 13],
         [ 6, 10, 14],
         [ 7, 11, 15]]])

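torch.Tensor.unfold(dimension, size, step) slices the tensor along the given dimension into windows of length size, taken every step elements, and returns the reshaped tensor with the window as a new last dimension.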
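A smaller one-dimensional case makes the (dimension, size, step) arguments easier to see. This sketch uses only torch.arange and Tensor.unfold; the values 7, 3 and 2 are arbitrary. PyTorch documents the result of unfold as a view of the original tensor, so no data is copied:

```python
import torch

t = torch.arange(7)           # tensor([0, 1, 2, 3, 4, 5, 6])
# windows of 3 elements, a new window starting every 2 elements along dim 0
w = t.unfold(0, 3, 2)
print(w)
# tensor([[0, 1, 2],
#         [2, 3, 4],
#         [4, 5, 6]])
print(w.shape)                # torch.Size([3, 3])

# on a 2-D tensor the window is appended as a NEW last dimension
a = torch.arange(16).view(4, 4)
print(a.unfold(0, 3, 1).shape)  # torch.Size([2, 4, 3])
```

The number of windows follows the same rule as a convolution output: (7 - 3) // 2 + 1 = 3.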

c = b.unfold(1,3,1)
c
tensor([[[[ 0,  1,  2],
          [ 4,  5,  6],
          [ 8,  9, 10]],
         [[ 1,  2,  3],
          [ 5,  6,  7],
          [ 9, 10, 11]]],
        [[[ 4,  5,  6],
          [ 8,  9, 10],
          [12, 13, 14]],
         [[ 5,  6,  7],
          [ 9, 10, 11],
          [13, 14, 15]]]])
c.shape
torch.Size([2, 2, 3, 3])

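Tensor c now holds every 3×3 sliding window of a with stride 1, so a has been expanded into a 4-D tensor of shape [2, 2, 3, 3]: c[i, j] is the window whose top-left element is a[i, j].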
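These windows are all a single-channel, stride-1 convolution without padding needs: multiply each window elementwise by the kernel and sum. The sketch below uses a float version of a and a hypothetical 3×3 kernel chosen only for illustration (not symmetric, so the orientation matters), and compares the result with torch.nn.functional.conv2d. Like PyTorch, it does not flip the kernel, i.e. it computes cross-correlation:

```python
import torch
import torch.nn.functional as F

a = torch.arange(16, dtype=torch.float32).view(4, 4)
# hypothetical kernel, for illustration only
kernel = torch.arange(9, dtype=torch.float32).view(3, 3)

# c[i, j] is the 3x3 window of a whose top-left element is a[i, j]; shape [2, 2, 3, 3]
c = a.unfold(0, 3, 1).unfold(1, 3, 1)
# multiply every window elementwise by the kernel, then sum over the window
out = (c * kernel).sum(dim=(-2, -1))  # shape [2, 2]

# PyTorch reference: input (N, C, H, W), weight (out_channels, in_channels, kH, kW)
ref = F.conv2d(a.view(1, 1, 4, 4), kernel.view(1, 1, 3, 3)).view(2, 2)
print(out)                        # values: [[258, 294], [402, 438]]
print(torch.allclose(out, ref))   # True
```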

A convolution layer on 4-D tensors (batch, channels, height, width), with padding and stride:

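import torch
def conv2d(x,weight,bias,stride,pad):
    n,c,h_in,w_in = x.shape  # batch, input channels, height, width
    d,c,k,j = weight.shape   # output channels, input channels, kernel height, kernel width
    # zero-pad the input (same dtype/device as x); pad:pad+h_in also works when pad == 0,
    # whereas pad:-pad would give an empty slice and fail
    x_pad = torch.zeros(n,c,h_in+2*pad,w_in+2*pad,dtype=x.dtype,device=x.device)
    x_pad[:,:,pad:pad+h_in,pad:pad+w_in] = x
    x_pad = x_pad.unfold(2,k,stride)  # shape (n, c, h_out, w_in+2*pad, k)
    x_pad = x_pad.unfold(3,j,stride)  # sliding windows, shape (n, c, h_out, w_out, k, j)
    out = torch.einsum(
        'nchwkj,dckj->ndhw',
        x_pad,weight)  # multiply each window by the kernel, summing over input channels and the window
    out = out + bias.view(1,-1,1,1)  # add one bias per output channel
    return out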
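import torch.nn.functional as F
x = torch.randn(2,3,5,5,requires_grad=True)  # input: batch 2, 3 channels, 5x5
w = torch.randn(4,3,3,3,requires_grad=True)  # weights: 4 output channels, 3 input channels, 3x3 kernel
b = torch.randn(4,requires_grad=True)        # one bias per output channel
stride = 2
pad = 2
torch_out = F.conv2d(x,w,b,stride,pad)  # PyTorch's built-in convolution, for reference
my_out = conv2d(x,w,b,stride,pad)       # the einsum implementation above
torch_out
my_out  # in a notebook only this last expression is displayed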
tensor([[[[ 0.7389,  2.6153,  1.6206,  0.7432],
          [ 3.3159, -1.5308,  9.9275, -1.2443],
          [ 3.0869, -6.5276,  5.8508, -2.7660],
          [-1.9878, -6.0596, -0.6992, -3.3871]],

         [[-0.5907,  1.0378, -2.1682, -1.2919],
          [-0.5426, -0.7781,  4.4606,  2.6235],
          [-4.9208,  2.5762, -0.1033, -2.2686],
          [ 0.8438, -2.4514,  2.3441, -1.5637]],

         [[ 1.6342,  1.5391,  4.0431,  4.2984],
          [-1.9671,  1.6227, -3.0477,  1.4082],
          [ 2.1579,  0.1513,  0.3556, -1.5150],
          [ 1.8514,  2.6099,  3.6082,  0.9121]],

         [[ 0.3274, -0.2762,  0.1335,  0.9362],
          [ 1.9674, -9.8901,  4.4833, -4.0852],
          [-4.3262,  0.1775, -0.3596,  1.7832],
          [ 3.7039, -2.4898,  5.7371, -1.6463]]],


        [[[-0.6281,  2.5599, -1.1673, -0.2803],
          [ 0.3624, -3.0622,  0.9032, -2.2624],
          [ 5.2199, 10.0974, -6.2536,  3.3783],
          [ 0.7550,  6.5702,  1.6907,  0.4545]],
...

         [[ 0.3033, -4.3135,  3.3039, -0.7272],
          [ 5.8496, -4.2414, -9.7936, -3.2630],
          [-5.2852, -5.4366,  8.8947, -1.0325],
          [-0.5529,  2.5634, -3.4046,  0.8185]]]], grad_fn=<AddBackward0>)
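The example above sets requires_grad=True on x, w and b but only prints the forward output. The sketch below also compares gradients with F.conv2d and tries a few (stride, pad) pairs, including pad=0. It repeats conv2d with pad:pad+h_in slicing so that the block runs on its own; torch.manual_seed(0), the tolerance atol=1e-4 and the particular (stride, pad) pairs are arbitrary choices, not taken from the original:

```python
import torch
import torch.nn.functional as F

def conv2d(x, weight, bias, stride, pad):
    n, c, h_in, w_in = x.shape
    d, c, k, j = weight.shape
    # zero padding written into a larger zero tensor (autograd tracks the copy)
    x_pad = torch.zeros(n, c, h_in + 2 * pad, w_in + 2 * pad, dtype=x.dtype, device=x.device)
    x_pad[:, :, pad:pad + h_in, pad:pad + w_in] = x
    # sliding windows: (n, c, h_out, w_out, k, j)
    x_pad = x_pad.unfold(2, k, stride).unfold(3, j, stride)
    out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)
    return out + bias.view(1, -1, 1, 1)

torch.manual_seed(0)  # reproducibility only
x = torch.randn(2, 3, 5, 5, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
b = torch.randn(4, requires_grad=True)

for stride, pad in [(2, 2), (1, 0), (1, 1)]:
    ref = F.conv2d(x, w, b, stride=stride, padding=pad)
    mine = conv2d(x, w, b, stride, pad)
    assert ref.shape == mine.shape
    assert torch.allclose(ref, mine, atol=1e-4)
    # backpropagate the same scalar through both graphs and compare gradients
    g_ref = torch.autograd.grad(ref.sum(), (x, w, b))
    g_mine = torch.autograd.grad(mine.sum(), (x, w, b))
    assert all(torch.allclose(p, q, atol=1e-4) for p, q in zip(g_ref, g_mine))
print("outputs and gradients match")
```

Passing these checks is evidence that the einsum layer agrees with PyTorch for these shapes in both the forward and backward pass, not a proof for every input.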
posted @ 2024-05-06 10:23  Sun-Wind