2019CVPR论文 HIgh Resolution Representation Learning for Human Pose Estimation代码解读

 ps:转载请注明出处,谢谢。

以下简称HRnet

这篇论文我拖更了好久,早在半年前我就说我要更新和这篇文献相关的代码研读,一直是懒,然后代码太长,分析代码真的要有决心+耐心+毅力,不然的话很容易放弃的,一件事情你做了百分之99就等同于没有做,行百里者半九十,就是这个道理,希望所有在这个领域内的小白通过阅读文献,编写代码来提升自己,相信你自己,你挺棒的。

HRnet由最基本的三种块构成。第一种是普通的3x3的卷积它的结构如下。

 

它的代码如下

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

第二种是BasicBlock,它的结构如下。

当inchannels和outchannels不想等时就进行将采样。它的代码如下

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

第三种结构是三层的残差块,结构如下图。这个结构里面有一个参数叫做expansion的参数,这个参数用来控制卷积的输入输出通道数。

其中BN层和RELU层我就不单独的画出来了,想看可以去原文中找相应的代码,这一部分的代码如下

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                               momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

就和正常三层卷积的残差块是一样的道理。接下来就是高分辨率模块。首先会看到高分辨率模块的参数列表如下

class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(False)

check_branches()这个函数这个函数的作用是检查,在高分辨率模块中num_branches(int类型),和len(num_inchannels(里面的元素是int)),和len(num_channels(里面的元素是int))它们三个的值是否相等,如果不想等就报出异常。那么这三个变量是什么意思呢。我们首先看一下高分辨率模块的图。

这个图里面我画出的这个部分就是一个完整的高分辨率模块,num_branches代表的是有几个分支,就是融合的时候(交叉很多线那个地方),有几条线指向一组featuremaps,(不算多出来的新分支,那是新的stage要考虑的问题),那么我这个图里面是每个featuremaps组有两条线指向它,因此num_branches=2,num_inchannels和num_channels都是列表,他们表示的是featuremaps的输入通道数和输出通道数,因为同一个分支上的featuremaps的通道数是一致的,因此有几个分支(num_branches),len(num_inchannels)和len(num_channels)就是几。同样就有几种尺度的featuremaps的融合,这一点后面也会说到。下面是check_branches部分的代码。

def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

下面一个函数是_make_one_branch,代码如下

def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
           self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
                            momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

它的作用就是创建一个新的分支,就是

我画出的这个部分,因为HRnet的所有块都是基于两种残差块的,因此只有每个分支的第一个块是特殊的,我们首先要判断的是inchannels和channels*expansion是否相等,如果不相等(针对的是Bottleneck块),那么就要进行downsample的操作,将通道调整到一致,因为每个branch里面有num_blocks[branch_index]个块,除了第一个块不同,其他的全部相同,因此采用一个for循环,从1到num_blocks[branch_index]逐渐生成基本块即可。

make_branches函数是看看每个stage里面有多少branch,然后有几个就调用几次_make_one_branch函数,这部分代码就不再赘述了,就一个for循环,简单。

重点重点重点来了,我把字放大来讲,这部分实在实在是不好理解,如果你觉得你看起来没问题的话,那你不用看我唠叨了。

上面我讲解了各个参数的含义,fuse_layer有个地方特别难懂这也是我第一遍看没有看懂的原因,那么它说的到底是个啥意思呢,请看下面的图。

然后我们看一下代码如下

def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j],
                                  num_inchannels[i],
                                  1,
                                  1,
                                  0,
                                  bias=False),
                        nn.BatchNorm2d(num_inchannels[i], 
                                       momentum=BN_MOMENTUM),
                        nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3, 
                                            momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                            momentum=BN_MOMENTUM),
                                nn.ReLU(False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

 

首先要注意的是fuselayers里面是不包含我画叉的那一条红色的线的,只包含粉色线和蓝色线的操作,那双重循环里面的i代表什么呢,i代表的当前融合的branch,上面的图我画出了当i=0时,所有的featuremaps都融合到0这个分支的featuremaps上面去,j代表组成融合的featuremaps所对应的branchindex,那么这时候要分三种情况讨论。

第一种情况:j>i

此时j所在分支的featuremaps的分辨率比i要小,通道数要多,那此时需要先使用卷积对其进行通道的改变,然后进行上采样,上采样因子即scale_factor的大小是2的(j-i)次方,比如j-i=1时那此时j就在i下面一个分支,他们俩的分辨率就差2倍,如果j-i=2,说明j在i下面两个分支的位置,此时分辨率相差四倍,因为上采样因子时2的2次方,就是4.

第二种情况,j=i时,j=i时说明要参与融合的分支j和目标分支i在同一个branch上面,因此什么都不用做

第三种情况j<i,那此时j所在分支的分辨率比i所在的目标分支的分辨率要大,因此要进行降采样,就是改变stride来采样,那作者引用了一个参数k代表降采样的次数k,k 的范围在[0,i-j-1]的这个闭区间内。k每进行循环加1的时候就进行一次卷积改变通道并且降采样。

当k=i-j-1的时候说明此时 j 分支就在 i 分支的上面一个分支,因此直接输出通道就是 i 所在分支的输入通道即可,同时令stride=2改变featuremaps的大小。当k!=i-j-1的时候说明 j 所在分支的分辨率比 i 所在分支的分辨率高出不止2倍,例如 j=0,i=3时,那么令outchannel=in channel[j],这个我也不知道为啥这么做哈哈哈,因为这样改变通道没有意义,不如到通道一直不变到最后再变成目标分支 i 的通道呢,我觉得很奇怪。因为当他们的分辨率只差两倍的时候通道仍相差四倍,然后k=i-j-1直接变到i所在通道数了,很迷惑不知道为什么作者这么做。搞得很复杂。

那么fuselayers最后长啥样呢,它其实是这样 fuse_layers[fuse_layer0[j=0 to i=0,j=1 to i=0,j=2 to i=0],fuse_layer1[j=0 to i=1,j=1 to i=1,j=2 to i=1],fuse_layer2[j=0 to i=2,j=1 to i=2,j=2 to i=2]],就是这样一个列表,是个二维列表(其实就是数组,python没有数组就用列表代替),列表中的每一个to操作都是一个sequential或者none,代表j分支到目标i分支的操作。接下来就是forward了,这部分很简单这里不再赘述,代码如下。

 def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse

 

刚才那个图上面画叉的那条红线不在fuse_layers中,它在transition_layers中,代码中有体现,它的作用是每当一个fuse_layers产生时会生成新的分支,新的分支的输入源于它上一个stage的上一层branch的旧的分支,因此需要额外的考虑一下,不过代码不是很难理解,看看就懂了。代码如下

def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i],
                                  num_channels_cur_layer[i],
                                  3,
                                  1,
                                  1,
                                  bias=False),
                        nn.BatchNorm2d(
                            num_channels_cur_layer[i], momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i+1-num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i-num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(
                            inchannels, outchannels, 3, 2, 1, bias=False),
                        nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

HRnet是一个stage一个stage构建的,即横向构建,每次都需要判断pre_stage_channels和cur_stage_channels是否相等以判断transition部分要做什么,这点非常好理解。

这是之前画的可视化之后的结构图,它的transition和论文中说的不同,但是我看了另外一个博主的文章她也是这么画的所以估计是没有问题的,好了终于把这篇博客写完了,累啊,

 

posted @ 2020-06-27 14:16  daremosiranaihana  阅读(706)  评论(1编辑  收藏  举报