MobileNet V3 and Lite R-ASPP: A Summary

Main Contributions of MobileNetV3

Paper: https://arxiv.org/abs/1905.02244
Code: see the reference implementations listed at the end of this post.

In one sentence: MobileNetV3 builds on neural architecture search with the NetAdapt algorithm and proposes solutions to three problems: the original NetAdapt objective does not work well on small mobile models, the nonlinear activation (swish) is expensive to compute, and the layers at the end of MobileNetV2 account for a large share of latency. The contributions are:

  • Complementary search techniques (this targets the fact that, when NetAdapt searches small mobile models, accuracy changes far more dramatically than latency; the selection objective is therefore modified: the original NetAdapt keeps the proposal with the best accuracy change, while MobileNetV3 keeps the proposal that maximizes the ratio of accuracy change to latency change, $\Delta \text{Acc} / |\Delta \text{latency}|$)
  • A new efficient version of the nonlinearity (this addresses the high computational cost of the sigmoid in $\text{swish}(x) = x \cdot \sigma(x)$; the paper proposes $\text{h-swish}[x] = x \cdot \frac{\text{ReLU6}(x+3)}{6}$ as a cheap approximation of the original activation)
  • A new efficient network design (this targets the latency caused by the MobileNetV2-style layers at the end of the network: the avg-pooling layer that follows the feature-generation layer is moved forward, reducing latency while preserving accuracy; in addition, the filter count of the initial 3x3 convolution is reduced from 32 to 16)
  • A new efficient segmentation decoder (a lightweight version of DeepLab's ASPP: ASPP is combined with squeeze-and-excitation and skip connections, reducing the parameters of the DeepLabv3 head)
1. The NetAdapt Search Process

At each step, a set of new proposals is generated; each proposal reduces latency by at least $\delta$ compared with the previous step. For every proposal, the pretrained model from the previous step is used to fill in and prune the new network structure, which is then fine-tuned for T steps to get a reasonable estimate of its accuracy. The best proposal is selected according to a predefined metric.

The metric used by the original NetAdapt is the change in accuracy; MobileNetV3 instead uses the ratio of the accuracy change to the latency change, $\Delta \text{Acc} / |\Delta \text{latency}|$, mainly to handle the fact that accuracy and latency do not change at the same rate for small models on mobile devices.
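
A minimal sketch of the two selection rules, assuming hypothetical proposal records that already carry measured accuracy and latency deltas (all names here are illustrative, not from the paper's code):

from dataclasses import dataclass

@dataclass
class Proposal:
    name: str
    delta_acc: float      # accuracy change vs. the current model (usually negative)
    delta_latency: float  # latency change vs. the current model (negative = faster)

def select_netadapt(proposals):
    # Original NetAdapt: keep the proposal with the best accuracy change.
    return max(proposals, key=lambda p: p.delta_acc)

def select_mobilenet_v3(proposals):
    # MobileNetV3 modification: maximize delta_acc / |delta_latency|,
    # favoring proposals that buy a large latency reduction with a small accuracy drop.
    return max(proposals, key=lambda p: p.delta_acc / abs(p.delta_latency))

candidates = [
    Proposal("prune_layer_3", delta_acc=-0.2, delta_latency=-5.0),
    Proposal("prune_layer_7", delta_acc=-0.1, delta_latency=-1.0),
]
print(select_netadapt(candidates).name)      # prune_layer_7 (smallest accuracy loss)
print(select_mobilenet_v3(candidates).name)  # prune_layer_3 (better accuracy/latency trade-off)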

2. Redesign the expensive layers
  • 2.1 To address the latency introduced at the end of the network by MobileNetV2's inverted bottleneck design, the avg-pooling layer that follows the feature-generation layer is moved forward (a before/after sketch is given at the end of this section)

  • 2.2 Compared with MobileNetV2, MobileNetV3 also incorporates the squeeze-and-excitation idea

  • 2.3 SE module implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SeModule(nn.Module):
    """Squeeze-and-Excitation module: global pooling followed by a two-layer
    1x1-conv bottleneck that produces per-channel attention weights."""
    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size),
            hsigmoid()  # hard sigmoid, defined in section 3 below
        )

    def forward(self, x):
        # Reweight the input feature map channel-wise.
        return x * self.se(x)
  • 2.4 MobileNetV3 block (bneck) implementation
class Block(nn.Module):
    '''expand + depthwise + pointwise'''
    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
        super(Block, self).__init__()
        self.stride = stride
        self.se = semodule  # SeModule instance or None

        # 1x1 pointwise expansion
        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear
        # depthwise convolution (groups == channels)
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear
        # 1x1 pointwise projection (linear, no activation)
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)

        # Residual connection; a 1x1 conv matches channels when they differ.
        self.shortcut = nn.Sequential()
        if stride == 1 and in_size != out_size:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_size),
            )

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.se is not None:
            out = self.se(out)
        out = out + self.shortcut(x) if self.stride == 1 else out
        return out
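
To make the point in 2.1 concrete, here is a simplified before/after sketch of the last stage (not the authors' code; the paper's figure of the original last stage also includes a depthwise convolution and projection, omitted here for brevity). It uses PyTorch's built-in nn.Hardswish, which computes the same h-swish described in the next section, and the 160/960/1280 channel sizes of MobileNetV3-Large:

import torch
import torch.nn as nn

# Original-style last stage: the expensive 1x1 projections run on the full 7x7 map.
original_last_stage = nn.Sequential(
    nn.Conv2d(160, 960, 1, bias=False), nn.BatchNorm2d(960), nn.Hardswish(),
    nn.Conv2d(960, 1280, 1, bias=False), nn.BatchNorm2d(1280), nn.Hardswish(),
    nn.AdaptiveAvgPool2d(1),                  # pooling happens last
)

# Efficient last stage (MobileNetV3): pool first, then project the 1x1 features.
efficient_last_stage = nn.Sequential(
    nn.Conv2d(160, 960, 1, bias=False), nn.BatchNorm2d(960), nn.Hardswish(),
    nn.AdaptiveAvgPool2d(1),                  # avg-pooling moved forward
    nn.Conv2d(960, 1280, 1), nn.Hardswish(),  # now operates on a 1x1 feature map
)

x = torch.randn(1, 160, 7, 7)
print(original_last_stage(x).shape, efficient_last_stage(x).shape)  # both [1, 1280, 1, 1]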
3. Nonlinearities: Improving the Activation Function

The paper proposes $\text{h-swish}[x] = x \cdot \frac{\text{ReLU6}(x+3)}{6}$ as an approximation of the original activation $\text{swish}(x) = x \cdot \sigma(x)$, removing the latency cost of computing the sigmoid. The authors also observe that h-swish is most useful in the second half of the network, so it is only applied to the deeper layers.

PyTorch version:

class hswish(nn.Module):
    """h-swish(x) = x * ReLU6(x + 3) / 6, a piecewise-linear approximation of swish."""
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out


class hsigmoid(nn.Module):
    """Hard sigmoid: ReLU6(x + 3) / 6, used inside the SE module."""
    def forward(self, x):
        out = F.relu6(x + 3, inplace=True) / 6
        return out
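
A quick numerical check (not from the original post) that h-swish tracks swish closely while only using ReLU6 arithmetic:

import torch
import torch.nn.functional as F

x = torch.linspace(-6, 6, steps=121)
swish = x * torch.sigmoid(x)
h_swish = x * F.relu6(x + 3) / 6
print((swish - h_swish).abs().max())  # the gap stays small, largest around |x| = 3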
4. MobileNetV3-Large and MobileNetV3-Small Architectures
  • 4.1 MobileNetV3-Large architecture (specification table in the paper)

  • 4.2 MobileNetV3-Large implementation

class MobileNetV3_Large(nn.Module):
    def __init__(self, num_classes=1000):
        super(MobileNetV3_Large, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()

        # Block arguments: kernel, in, expand, out, nonlinearity, SE module, stride
        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
            Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(3, 40, 240, 80, hswish(), None, 2),
            Block(3, 80, 200, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
            Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
            Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
            Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
            Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
        )
        self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(960)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(960, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        # Weight initialization (a standard scheme: Kaiming for convs,
        # constants for batch norm, small normal noise for linear layers).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, 7)  # assumes a 224x224 input, giving a 7x7 map here
        out = out.view(out.size(0), -1)
        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out
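
A quick shape check, assuming the classes above are in scope. Note that the forward pass hard-codes a 7x7 average pool, so the model expects 224x224 input:

import torch

net = MobileNetV3_Large(num_classes=1000)
net.eval()                         # BatchNorm1d requires eval mode for a batch of one
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y = net(x)
print(y.shape)                     # torch.Size([1, 1000])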
  • 4.3 MobileNetV3-Small architecture (specification table in the paper)
  • 4.4 MobileNetV3-Small implementation
class MobileNetV3_Small(nn.Module):
    def __init__(self, num_classes=1000):
        super(MobileNetV3_Small, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()

        # Block arguments: kernel, in, expand, out, nonlinearity, SE module, stride
        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
            Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
        )

        self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(576)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(576, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        # Same weight initialization scheme as in MobileNetV3_Large above.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, 7)  # assumes a 224x224 input, giving a 7x7 map here
        out = out.view(out.size(0), -1)
        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out
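
As a rough sanity check (not a benchmark), the parameter counts of the two variants as implemented above can be compared; the paper reports about 5.4M parameters for Large and 2.5M for Small:

def count_params(model):
    return sum(p.numel() for p in model.parameters())

print(count_params(MobileNetV3_Large()) / 1e6)  # roughly 5M parameters
print(count_params(MobileNetV3_Small()) / 1e6)  # roughly 2-3M parameters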
5. Lite R-ASPP

Lite R-ASPP builds a lightweight segmentation decoder on top of the last block of MobileNetV3: features at different resolutions are combined with an SE-style attention branch and a skip connection, enabling semantic segmentation on mobile devices.

TensorFlow version: https://github.com/xiaochus/MobileNetV3/blob/master/model/LR_ASPP.py
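
For readers who prefer PyTorch, here is a minimal sketch of an LR-ASPP-style head (modeled on the structure described in the paper, not taken from the TensorFlow code linked above); low and high are feature maps from an earlier stage and from the last stage of the MobileNetV3 backbone, and the channel counts and class number below are illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LiteRASPPHead(nn.Module):
    """A 1x1-conv branch gated by an SE-like attention branch, plus a skip
    connection from a lower-level feature map."""
    def __init__(self, low_channels, high_channels, num_classes, inter_channels=128):
        super().__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
        )
        # SE-style attention: pooling, 1x1 conv, sigmoid (the paper uses a
        # large-stride average pool; global pooling is used here for simplicity).
        self.scale = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.Sigmoid(),
        )
        self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
        self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)

    def forward(self, low, high):
        x = self.cbr(high) * self.scale(high)   # attention-gated high-level features
        x = F.interpolate(x, size=low.shape[-2:], mode='bilinear', align_corners=False)
        return self.low_classifier(low) + self.high_classifier(x)  # skip connection

head = LiteRASPPHead(low_channels=40, high_channels=960, num_classes=19)
low = torch.randn(1, 40, 64, 64)     # higher-resolution, lower-level features
high = torch.randn(1, 960, 16, 16)   # low-resolution features from the last block
print(head(low, high).shape)         # torch.Size([1, 19, 64, 64])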

References
  • https://arxiv.org/abs/1905.02244
  • https://github.com/xiaolai-sqlai/mobilenetv3
  • https://github.com/xiaochus/MobileNetV3