maxpool3d修改成maxpool2d与maxpool1d方法

有时候遇到不支持maxpool3d的硬件或算子时候,可将其改成maxpool2d加上maxpool1d组合方式表示,经验证与maxpool3d结果完全一致,其实现细节如下:

代码:

import torch


class MaxPool3d_modify(torch.nn.Module):
    def __init__(self, kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)):
        super(MaxPool3d_modify, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.max_pool_2d = torch.nn.MaxPool2d(kernel_size[1:], self.stride[1:], padding[1:])
        self.max_pool_1d = torch.nn.MaxPool1d(kernel_size=kernel_size[0], stride=self.stride[0],
                                              padding=self.padding[0])  # stride is kernal_size

    def forward(self, x):
        x1 = self.max_pool_2d(x)
        x = x1.squeeze(0).permute(1, 2, 0)
        x = self.max_pool_1d(x)

        x = x.permute(2, 0, 1).unsqueeze(0)

        return x


if __name__ == '__main__':
    '''
    torch.nn.MaxPool3d处理维度4或5,[b,c,h,w]或[b,c,f,h,w] 处理维度为c,h,w或f,h,w
    torch.nn.MaxPool2d处理维度为4,[b,c,h,w]处理h,w维度pool
    torch.nn.MaxPool1d处理维度为3,[d1,d2,d3]处理d3维度pool
    '''
    input_ori = torch.rand(1, 128, 20, 90)  # 64,18,44  kernel_size=(1, 3, 3), stride=(2, 1, 2)
    model1 = torch.nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 3, 2), padding=(0, 0, 0))
    model2 = MaxPool3d_modify(kernel_size=(1, 3, 3), stride=(4, 3, 2), padding=(0, 0, 0))
    o1 = model1(input_ori)
    print('\noutput1 shape: ', o1.shape)
    o2 = model2(input_ori)
    print('\noutput2 shape: ', o2.shape)
    output1 = o1.reshape(-1)
    output2 = o2.reshape(-1)
    n = 0
    for i, o in enumerate(output1):
        if o == output2[i]:
            n = n + 1
    print('precision', n / len(output1))
    

结果展示:

 

 

 

posted @ 2023-03-02 11:23  tangjunjun  阅读(244)  评论(0编辑  收藏  举报
https://rpc.cnblogs.com/metaweblog/tangjunjun