将 MaxPool3d 改写为 MaxPool2d 与 MaxPool1d 组合的方法
当硬件或推理框架不支持 MaxPool3d 算子时，可以将其改写为 MaxPool2d 与 MaxPool1d 的组合。由于最大值运算可以按维度拆分（先在每个深度切片的 H、W 上取窗口最大值，再沿深度方向取窗口最大值），两种方式的结果完全一致。实现细节如下：
代码:
import torch


class MaxPool3d_modify(torch.nn.Module):
    """Drop-in replacement for ``torch.nn.MaxPool3d`` built from MaxPool2d + MaxPool1d.

    Max pooling is separable: the max over a (kD, kH, kW) window equals the
    max, over kD consecutive depth slices, of each slice's (kH, kW) max.
    The module therefore pools H and W with ``MaxPool2d`` (depth rides along as
    the channel axis), then pools the depth axis with ``MaxPool1d``. This is
    useful on hardware / runtimes that lack a 3-D pooling operator.

    Args:
        kernel_size: int or (kD, kH, kW) window size.
        stride: int or (sD, sH, sW). ``None`` means "same as kernel_size",
            mirroring ``torch.nn.MaxPool3d``.
        padding: int or (pD, pH, pW) implicit padding on each side.

    Input is ``(C, D, H, W)`` or ``(N, C, D, H, W)``. All dims before the last
    three are kept as-is, and the output has ``D, H, W`` replaced by the
    pooled sizes.
    """

    def __init__(self, kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)):
        super().__init__()
        kernel_size = self._triple(kernel_size)
        stride = kernel_size if stride is None else self._triple(stride)
        padding = self._triple(padding)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Pools the (H, W) plane of every depth slice independently.
        self.max_pool_2d = torch.nn.MaxPool2d(kernel_size[1:], stride[1:], padding[1:])
        # Pools along depth only.
        self.max_pool_1d = torch.nn.MaxPool1d(
            kernel_size=kernel_size[0], stride=stride[0], padding=padding[0]
        )

    @staticmethod
    def _triple(value):
        """Expand an int to a 3-tuple; convert any other sequence to a tuple."""
        if isinstance(value, int):
            return (value, value, value)
        return tuple(value)

    def forward(self, x):
        """Max-pool the last three dims (D, H, W) of ``x``."""
        *lead, depth, height, width = x.shape
        # Fold every leading dim into one batch axis -> (B, D, H, W); MaxPool2d
        # treats D as channels and pools H, W -> (B, D, H', W').
        x = self.max_pool_2d(x.reshape(-1, depth, height, width))
        batch, _, out_h, out_w = x.shape
        # Move depth to the last axis so MaxPool1d pools it: (B, H'*W', D).
        x = x.permute(0, 2, 3, 1).reshape(batch, out_h * out_w, depth)
        x = self.max_pool_1d(x)  # (B, H'*W', D')
        out_d = x.shape[-1]
        # Back to (B, D', H', W'), then restore the original leading dims.
        x = x.reshape(batch, out_h, out_w, out_d).permute(0, 3, 1, 2)
        return x.reshape(*lead, out_d, out_h, out_w)


if __name__ == '__main__':
    # torch.nn.MaxPool3d accepts 4-D (C, D, H, W) or 5-D (N, C, D, H, W) input
    # and pools the last three dims.
    # torch.nn.MaxPool2d accepts (N, C, H, W) and pools H, W.
    # torch.nn.MaxPool1d accepts (N, C, L) and pools L.
    input_ori = torch.rand(1, 128, 20, 90)
    model1 = torch.nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 3, 2), padding=(0, 0, 0))
    model2 = MaxPool3d_modify(kernel_size=(1, 3, 3), stride=(4, 3, 2), padding=(0, 0, 0))
    o1 = model1(input_ori)
    print('\noutput1 shape: ', o1.shape)
    o2 = model2(input_ori)
    print('\noutput2 shape: ', o2.shape)
    if o1.shape != o2.shape:
        # Element-wise comparison is meaningless when the shapes disagree.
        print('shape mismatch:', tuple(o1.shape), 'vs', tuple(o2.shape))
    else:
        # Fraction of exactly-equal elements (1.0 means identical outputs).
        n = int((o1 == o2).sum().item())
        print('precision', n / o1.numel())
结果展示: