动手实现深度学习(13)池化层的实现
10.1 池化层的运算
传送门: https://www.cnblogs.com/greentomlee/p/12314064.html
github: Leezhen2014: https://github.com/Leezhen2014/python_deep_learning
池化层的forward
常见的池化(Pooling)有 mean-pool、max-pool 和 min-pool 等几类, 本章只讨论 max-pool。
以下是 max-pool 的 forward 运算(参考下方链接):
https://blog.csdn.net/nanhuaibeian/article/details/100664570
池化层的backward的运算
Max-pool 的反向传播: 每个输出单元的梯度只传回到它对应的池化窗口(大小为 pool_h*pool_w)中、前向传播时取得最大值的那个位置, 窗口内其余位置的梯度填充为 0; 若窗口有重叠, 落到同一输入位置的梯度相加。
10.2 池化层的实现
class Pooling:
    """Max-pooling layer for 4-D NCHW input, implemented on top of im2col/col2im.

    forward() takes x of shape (N, C, H, W) and returns (N, C, out_h, out_w),
    where out_h = (H + 2*pad - pool_h) // stride + 1 (likewise for out_w).
    backward() routes each upstream gradient to the position that held the
    maximum of its pooling window; every other position receives 0.
    """

    def __init__(self, pool_h, pool_w, stride=1, pad=0):
        """
        Args:
            pool_h: height of the pooling window.
            pool_w: width of the pooling window.
            stride: step between neighbouring windows (same in both directions).
            pad: zero padding passed through to im2col/col2im.
        """
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad

        # Cached by forward() for use in backward().
        self.x = None
        self.arg_max = None

    def forward(self, x):
        """Return the max over every pooling window of x (N, C, H, W).

        Raises:
            ValueError: if the (padded) input is smaller than the pooling window.
        """
        N, C, H, W = x.shape
        # im2col pads the input by self.pad, so the padding must be part of the
        # output-size formula; otherwise the reshape below fails for pad > 0.
        out_h = (H + 2 * self.pad - self.pool_h) // self.stride + 1
        out_w = (W + 2 * self.pad - self.pool_w) // self.stride + 1
        if out_h < 1 or out_w < 1:
            raise ValueError(
                "pooling window (%d, %d) does not fit padded input (%d, %d)"
                % (self.pool_h, self.pool_w, H + 2 * self.pad, W + 2 * self.pad)
            )

        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        # One row per (n, out_y, out_x, c) window; this assumes im2col lays out
        # columns channel-major (C blocks of pool_h*pool_w) -- TODO confirm.
        col = col.reshape(-1, self.pool_h * self.pool_w)

        arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1)
        # Rows come out in (N, out_h, out_w, C) order; move channels back to axis 1.
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

        self.x = x
        self.arg_max = arg_max

        return out

    def backward(self, dout):
        """Return dL/dx (same shape as the forward input) given dout (N, C, out_h, out_w)."""
        # Match the (N, out_h, out_w, C) row order used in forward().
        dout = dout.transpose(0, 2, 3, 1)

        pool_size = self.pool_h * self.pool_w
        dmax = np.zeros((dout.size, pool_size))
        # Scatter each gradient into the slot that was the max of its window.
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
        dmax = dmax.reshape(dout.shape + (pool_size,))

        # Back to im2col layout: one row per output location, C*pool_size columns.
        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)

        return dx
10.3 pool单元测试
测试程序:
# -*- coding: utf-8 -*-
# @File  : test_im2col.py
# @Author: lizhen
# @Date  : 2020/2/14
# @Desc  : Unit test for the max-pooling layer (Pooling), which is built on im2col/col2im.
import numpy as np

from src.common.util import im2col, col2im
from src.common.layers import Convolution, Pooling


def main():
    """Run Pooling forward/backward on a fixed 1x3x4x4 input and check the results."""
    raw_data = [
        3, 0, 4, 2,
        6, 5, 4, 3,
        3, 0, 2, 3,
        1, 0, 3, 1,

        1, 2, 0, 1,
        3, 0, 2, 4,
        1, 0, 3, 2,
        4, 3, 0, 1,

        4, 2, 0, 1,
        1, 2, 0, 4,
        3, 0, 4, 2,
        6, 2, 4, 5,
    ]
    x = np.array(raw_data).reshape(1, 3, 4, 4)  # (N, C, H, W)

    print("===================")
    pool = Pooling(pool_h=2, pool_w=2, stride=2, pad=0)
    out = pool.forward(x)
    print(out.shape)
    print(out)

    # Hand-computed 2x2 / stride-2 max-pool of each channel above.
    expected = np.array([
        [[6, 4], [3, 3]],
        [[3, 4], [4, 3]],
        [[4, 4], [6, 5]],
    ]).reshape(1, 3, 2, 2)
    np.testing.assert_array_equal(out, expected)

    # Non-overlapping windows: each upstream gradient lands on exactly one
    # input position, so the total gradient mass is preserved.
    dout = np.ones_like(out, dtype=float)
    dx = pool.backward(dout)
    assert dx.shape == x.shape
    assert np.isclose(dx.sum(), dout.sum())


if __name__ == '__main__':
    main()
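对应输出: out.shape 为 (1, 3, 2, 2); 按上述输入手算, 三个通道的 2x2、stride=2 max-pool 结果依次为 [[6, 4], [3, 3]]、[[3, 4], [4, 3]]、[[4, 4], [6, 5]]。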
我心匪石,不可转也。我心匪席,不可卷也。