Implementing Deep Learning by Hand (13): The Pooling Layer

10.1 Pooling Layer Operations

Link: https://www.cnblogs.com/greentomlee/p/12314064.html

github: Leezhen2014: https://github.com/Leezhen2014/python_deep_learning

Forward pass of the pooling layer

Common pooling types include mean-pool, max-pool and min-pool. This chapter covers only max-pool.
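As a quick illustration of how the three variants differ, the sketch below applies mean, max and min pooling to the first channel of the test input from 10.3 (2×2 window, stride 2). It assumes non-overlapping windows and even H and W, so a plain reshape can stand in for im2col; `pool2x2` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def pool2x2(x, reduce):
    """Non-overlapping 2x2 pooling (stride 2) on a 2-D array.

    Assumes x.shape is (H, W) with H and W both even.
    """
    H, W = x.shape
    # Split into blocks of shape (H/2, 2, W/2, 2);
    # axes 1 and 3 index the row/column inside each window.
    blocks = x.reshape(H // 2, 2, W // 2, 2)
    return reduce(blocks, axis=(1, 3))


# First channel of the test input used in 10.3
x = np.array([[3, 0, 4, 2],
              [6, 5, 4, 3],
              [3, 0, 2, 3],
              [1, 0, 3, 1]])

mean_out = pool2x2(x, np.mean)
max_out = pool2x2(x, np.max)
min_out = pool2x2(x, np.min)
```

For this channel, max pooling gives [[6, 4], [3, 3]], mean pooling [[3.5, 3.25], [1.0, 2.25]] and min pooling [[0, 2], [0, 1]].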

The forward computation is shown below:

[Figure: max-pool forward pass; see also https://blog.csdn.net/nanhuaibeian/article/details/100664570]
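Before the vectorized version in 10.2, a loop-based forward pass may make the figure easier to follow. This is a minimal sketch assuming no padding; `maxpool_forward_naive` is a hypothetical name. It also records which position inside each window held the maximum, since the backward pass needs that.

```python
import numpy as np


def maxpool_forward_naive(x, pool_h, pool_w, stride):
    """Loop-based max-pool forward pass for x of shape (N, C, H, W), no padding.

    Returns the pooled output and, for every output element, the flat index
    (0 .. pool_h*pool_w-1) of the maximum inside its window.
    """
    N, C, H, W = x.shape
    out_h = (H - pool_h) // stride + 1
    out_w = (W - pool_w) // stride + 1
    out = np.zeros((N, C, out_h, out_w), dtype=x.dtype)
    arg_max = np.zeros((N, C, out_h, out_w), dtype=np.int64)
    for n in range(N):
        for c in range(C):
            for i in range(out_h):
                for j in range(out_w):
                    # Window whose top-left corner is (i*stride, j*stride)
                    h0, w0 = i * stride, j * stride
                    window = x[n, c, h0:h0 + pool_h, w0:w0 + pool_w]
                    # On ties, np.argmax returns the first maximum (row-major order)
                    arg_max[n, c, i, j] = np.argmax(window)
                    out[n, c, i, j] = window.max()
    return out, arg_max


x = np.arange(16).reshape(1, 1, 4, 4)
out, arg_max = maxpool_forward_naive(x, pool_h=2, pool_w=2, stride=2)
```

On `np.arange(16)` reshaped to 1×1×4×4, the output is [[5, 7], [13, 15]] and every window's maximum sits at flat index 3 (its bottom-right corner).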

Backward pass of the pooling layer

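In max-pool backpropagation, each output gradient is expanded back to the input window it came from (pool_h×pool_w, which equals stride_h×stride_w in the common case where the stride matches the window size): the position that held the maximum in the forward pass receives the gradient, and every other position in the window is filled with 0.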

[Figure: max-pool backward pass]
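The same rule written as explicit loops, again a sketch that assumes no padding (`maxpool_backward_naive` is a hypothetical name). It recomputes each window's argmax from the forward input `x`; note that `np.argmax` picks the first maximum on ties, so the window [[4, 2], [4, 3]] sends its gradient to the top-left 4.

```python
import numpy as np


def maxpool_backward_naive(dout, x, pool_h, pool_w, stride):
    """Loop-based max-pool backward pass (no padding).

    dout has shape (N, C, out_h, out_w); x is the forward input (N, C, H, W).
    Each upstream gradient goes to the position of its window's maximum;
    every other position in the window receives 0.
    """
    N, C, out_h, out_w = dout.shape
    dx = np.zeros(x.shape, dtype=np.float64)
    for n in range(N):
        for c in range(C):
            for i in range(out_h):
                for j in range(out_w):
                    h0, w0 = i * stride, j * stride
                    window = x[n, c, h0:h0 + pool_h, w0:w0 + pool_w]
                    # Row/column of the (first) maximum inside the window
                    r, s = np.unravel_index(np.argmax(window), window.shape)
                    # += so that overlapping windows (stride < pool size) accumulate
                    dx[n, c, h0 + r, w0 + s] += dout[n, c, i, j]
    return dx


x = np.array([[[[3, 0, 4, 2],
                [6, 5, 4, 3],
                [3, 0, 2, 3],
                [1, 0, 3, 1]]]])
dout = np.array([[[[1.0, 2.0],
                   [3.0, 4.0]]]])
dx = maxpool_backward_naive(dout, x, pool_h=2, pool_w=2, stride=2)
```

With the first channel of the test input and an upstream gradient [[1, 2], [3, 4]], dx is nonzero only at (1, 0), (0, 2), (2, 0) and (2, 3), holding 1, 2, 3 and 4 respectively.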

 

10.2 Implementing the Pooling Layer

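import numpy as np

from src.common.util import im2col, col2im


class Pooling:
    """Max-pooling layer operating on NCHW tensors."""

    def __init__(self, pool_h, pool_w, stride=1, pad=0):
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad

        self.x = None        # input cached for backward
        self.arg_max = None  # position of the max inside each window

    def forward(self, x):
        N, C, H, W = x.shape
        # Output size; padding is added on both sides, so it must be counted here
        out_h = (H + 2 * self.pad - self.pool_h) // self.stride + 1
        out_w = (W + 2 * self.pad - self.pool_w) // self.stride + 1

        # im2col gives one row per (n, out_h, out_w) with columns (C, pool_h, pool_w);
        # splitting the columns leaves one window of one channel per row
        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        col = col.reshape(-1, self.pool_h * self.pool_w)

        arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1)
        # Rows are ordered (N, out_h, out_w, C); convert back to NCHW
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

        self.x = x
        self.arg_max = arg_max

        return out

    def backward(self, dout):
        # Reorder to (N, out_h, out_w, C) to match the row order of col
        dout = dout.transpose(0, 2, 3, 1)

        pool_size = self.pool_h * self.pool_w
        # Put each upstream gradient at the argmax position of its window;
        # all other positions stay 0
        dmax = np.zeros((dout.size, pool_size))
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
        dmax = dmax.reshape(dout.shape + (pool_size,))

        # Back to the (N*out_h*out_w, C*pool_h*pool_w) layout produced by im2col,
        # then fold the windows back into the input shape
        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)

        return dx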
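The least obvious line in `backward` is the fancy-indexing assignment `dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()`. The standalone snippet below uses made-up values (three windows of size 4) to show what it does: row k of `dmax` gets `dout[k]` in column `arg_max[k]`, and everything else stays 0.

```python
import numpy as np

# Each row of col is one pooling window, flattened (pool_size = 4 here).
col = np.array([[3, 0, 6, 5],
                [4, 2, 4, 3],
                [3, 0, 1, 0]])
arg_max = np.argmax(col, axis=1)      # index of the max in each row
out = np.max(col, axis=1)             # pooled values

dout = np.array([10.0, 20.0, 30.0])   # upstream gradient, one value per window
dmax = np.zeros((dout.size, col.shape[1]))
# Pair row k with column arg_max[k] and write dout[k] there; the rest stays 0
dmax[np.arange(arg_max.size), arg_max] = dout
```

Here `arg_max` is [2, 0, 0], so `dmax` becomes [[0, 0, 10, 0], [20, 0, 0, 0], [30, 0, 0, 0]].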

 

10.3 Unit Test for the Pooling Layer

The test data is as follows: [Figure: test input]

 

Data after im2col: [Figure: im2col output]

 

Data after max pooling: [Figure: max-pool output]

 

Test program:

# -*- coding: utf-8 -*-
# @File  : test_im2col.py
# @Author: lizhen
# @Date  : 2020/2/14
# @Desc  : Test im2col
import numpy as np

from src.common.util import im2col, col2im
from src.common.layers import Convolution, Pooling


if __name__ == '__main__':
    raw_data = [3, 0, 4, 2,
                6, 5, 4, 3,
                3, 0, 2, 3,
                1, 0, 3, 1,

                1, 2, 0, 1,
                3, 0, 2, 4,
                1, 0, 3, 2,
                4, 3, 0, 1,

                4, 2, 0, 1,
                1, 2, 0, 4,
                3, 0, 4, 2,
                6, 2, 4, 5
    ]

    raw_filter = [
        1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1,
        2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2,
    ]

    input_data = np.array(raw_data)
    filter_data = np.array(raw_filter)

    x = input_data.reshape(1, 3, 4, 4)  # NCHW
    W = filter_data.reshape(2, 3, 2, 2)  # (FN, C, FH, FW): 2 filters, 3 channels, 2x2
    b = np.zeros(2)
    # b = b.reshape((2,1))
    # col1 = im2col(input_data=x,filter_h=2,filter_w=2,stride=1,pad=0)#input_data, filter_h, filter_w, stride=1, pad=0
    # print(col1)

    # print("input_data.shape=%s"%str(input_data.shape))
    # print("W.shape=%s"%str(W.shape))
    # print("b.shape=%s"%str(b.shape))
    # conv = Convolution(W,b) # def __init__(self, W, b, stride=1, pad=0)
    # out = conv.forward(x)
    # print("bout.shape=%s"%str(out.shape))
    # print(out)

    print("===================")
    pool = Pooling(pool_h=2, pool_w=2, stride=2, pad=0)
    out = pool.forward(x)
    print(out.shape)
    print(out)
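One way to check the printed result independently: for non-overlapping 2×2 windows with stride 2, max pooling reduces to a reshape followed by `max`. The sketch below recomputes the expected output for the test input under that assumption; it does not use the repository's `im2col` or `Pooling`.

```python
import numpy as np

raw_data = [3, 0, 4, 2, 6, 5, 4, 3, 3, 0, 2, 3, 1, 0, 3, 1,
            1, 2, 0, 1, 3, 0, 2, 4, 1, 0, 3, 2, 4, 3, 0, 1,
            4, 2, 0, 1, 1, 2, 0, 4, 3, 0, 4, 2, 6, 2, 4, 5]
x = np.array(raw_data).reshape(1, 3, 4, 4)  # NCHW, same as the test program

N, C, H, W = x.shape
# Valid only for non-overlapping 2x2 windows with stride 2 (H and W even):
# axes 3 and 5 index the row/column inside each window.
expected = x.reshape(N, C, H // 2, 2, W // 2, 2).max(axis=(3, 5))
```

This yields shape (1, 3, 2, 2) with channels [[6, 4], [3, 3]], [[3, 4], [4, 3]] and [[4, 4], [6, 5]]; if `Pooling.forward` is correct, `np.array_equal(out, expected)` should be True for the test program's `out`.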

Corresponding output:

[Figure: console output of the test program]

posted @ 2022-09-12 18:35  修雨轩陈