[pytorch][模型压缩] 通道裁剪后的模型设计——以MobileNet和ResNet为例
说明
模型裁剪可分为两种,一种是稀疏化裁剪,裁剪的粒度为值级别,一种是结构化裁剪,最常用的是通道裁剪。通道裁剪是减少输出特征图的通道数,对应的权值是卷积核的个数。
问题
通常模型裁剪的三个步骤是:1. 判断网络中不重要的通道 2. 删减掉不重要的通道(一般不会立即删,加mask等到评测时才开始删) 3. 将模型导出,然后进行finetue恢复精度。
步骤1,2涉及到非常多的标准和方法,这里不去深究。但是到第3步的时候,怎么导出网络,看似很简单的问题,但是如果碰到resnet这种,是要花费时间研究细节的,而且目前还没有人专门讲这块(实际上是个工程实现问题),下面来详细说说。
以MobileNet为代表的模型
先考虑以mobilenet为代表的模型,mobilenet中包含了一系列块,每块中包含了深度可分离卷积和点卷积,然后整个模型就是一系列block块的堆叠,在目前很多模型中都具有代表性。
首先我们只考虑了模型的11卷积,因为11卷积是最耗算力的,而33卷积的裁剪实际上没有必要,意味可分离意味着将输入特征图的信息丢掉,与其丢掉,那不如在一开始就不去计算要丢掉的那部分,而不计算的那部分正是由前一层的11点卷积得到的,也就是说改变前一层的输出通道,就等同于对当前的可分离卷积的裁剪。
然后问题就只剩下11卷积核的裁剪了,那么需要在模型初始化时设置不同的profile,来实现不同结构的模型裁剪模型,这里代码中的例子是将第一个block中11卷积核的128通道裁剪为64通道,其他通道可依次次类推。
class MobileNet(nn.Module):
def __init__(self, n_class, profile='normal', channels=None):
self.channels = [32, 64, 104, 128, 248, 224, 456, 296, 456, 224, 104, 104, 208, 208]
if channels:
self.channsels = channels
super(MobileNet, self).__init__()
# original
if profile == 'normal':
in_planes = 32
cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2), 512, 512, 512, 512, 512, (1024, 2), (1024,1)]
# 0.5 AMC
elif profile == '0.5flops':
in_planes = self.channels[0]
strides = [1, 2, 1, 2, 1, 2, 1,1,1,1,1, 2,1]
cfg = list(zip(self.channels[1:], strides))
else:
raise NotImplementedError
而在make_layers部分,需要判断当前stride, 有三次stride,每次缩放一倍,默认stride都是1,当然也可以把stride全列举出来,就不用判断了。
def _make_layers(self, in_planes, cfg, layer):
layers = []
for x in cfg:
out_planes = x if isinstance(x, int) else x[0]
stride = 1 if isinstance(x, int) else x[1]
layers.append(layer(in_planes, out_planes, stride))
in_planes = out_planes
return nn.Sequential(*layers)
以Resnet50为代表的模型
前面解决了mobileNet的问题,其实也是一个基本网络架构下的裁剪问题,但是目前的网络往往具有复杂的连接,比如像resnet这样,具有残差结构的单元块,这意味着残差部分需要单独处理。
在我压缩完成得到压缩配置之后,先写了简单版本的resnet_pruning版本,这是最朴素的思想:
def ResNet50_Pruning(**kwarg):
model = ResNet(Bottleneck_Pruning, [3,4,6,3], **kwarg)
p = 0
actions = [3, 56, 64, 64, 48, 240, 16, 64, 152, 32, 32, 152, 120, 104, 216, 368, 112, 32, 480, 112, 120, 504, 88, 104, 104, 240, 184, 368, 768, 200, 200, 640, 232, 192, 976, 248, 192, 760, 160, 208, 584, 208, 248, 968, 496, 224, 208, 416, 104, 104, 416, 104, 104, 416]
for i, m in enumerate(model.modules()):
if type(m) in (nn.Conv2d, nn.Linear):
if type(m) == nn.Conv2d and m.groups == m.in_channels: # depth-wise conv, buffer
continue
else:
if type(m) is nn.Linear:
m.in_features = actions[p]
else:
m.in_channels = actions[p]
m.out_channels = actions[p+1]
p += 1
return model
将每一层对应的actions都找到,然后令其channel都做出改变,这无疑是最直观的写法,但是由于CONV之后往往带着BN层,当改完CONV之后,你发现BN还是原有的值,这就会使得维度不匹配。当然我有参考了Nvi-Lab的写法,可以先建模型,然后获取压缩的action的裁剪通道,然后重建一个new_conv代替原有的conv,这样写也行,也是一种思路,不过我觉得这样不优雅,而且容易漏东西。
然后我使用了另外一个思路,在建立模型的时候就建立一个裁剪之后的模型,但是由于resnet50有多个blocks,然后每个block中是一个瓶颈结构,这就需要你定位到哪一个block,以及该block中哪一个卷积,这样就存在的一个缺陷就是需要全局的index来标记当前层是第几层。具体的实现如下:
cfg = [3, 56, 64, 64, 48, 240, 16, 64, 152, 32, 32, 152, 120, 104, 216, 368, 112, 32, 480, 112, 120, 504, 88, 104, 104, 240, 184, 368, 768, 200, 200, 640, 232, 192, 976, 248, 192, 760, 160, 208, 584, 208, 248, 968, 496, 224, 208, 416, 104, 104, 416, 104, 104, 416]
shortcut = [1, 10, 22, 40]
block_nums = [3, 4, 6, 3]
def Sample(x, num):
np.random.seed(2019)
batch_size, channel_num, height, width = x.data.size()
channel_index = np.random.choice(channel_num, num)
x = x[:, channel_index, :, :]
return x
class Bottleneck(nn.Module):
def __init__(self, in_planes, planes, stride=1, offset=1):
super(Bottleneck, self).__init__()
# pw
self.conv1 = nn.Conv2d(cfg[offset], cfg[offset+1], kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(cfg[offset+1])
# dw
self.conv2 = nn.Conv2d(cfg[offset+1], cfg[offset+2], kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(cfg[offset+2])
# pw
self.conv3 = nn.Conv2d(cfg[offset+2], cfg[offset+3], kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(cfg[offset+3])
self.shortcut = nn.Sequential()
if offset in shortcut:
p = shortcut.index(offset)
self.shortcut = nn.Sequential(
nn.Conv2d(cfg[offset], cfg[offset+3], kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(cfg[offset+3])
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
s = self.shortcut(x)
if s.data.size() != out.data.size():
s = Sample(s, out.data.size(1))
out += s
out = F.relu(out)
return out
可以看到,对于残差那一分支,采用的是sample采样的方法来使得通道数与瓶颈结构相同,之所以不对瓶颈结构中的卷积结果进行采样,是由于这样可以尽可能多地保留输入特征的信息。而瓶颈结构中多了offset参数用以标记当前的卷积的索引。