Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks - Model Compression - 3 - Code Study: VGG16, ResNet

VGG16

Running run/vgg16/vgg16_prune_demo.py:

 python ./run/vgg16/vgg16_prune_demo.py --config ./run/vgg16/prune.json

Error:

Traceback (most recent call last):
  File "./run/vgg16/vgg16_prune_demo.py", line 16, in <module>
    from logger import logger
  File "/Users/user/pytorch/gate-decorator-pruning/logger.py", line 67, in <module>
    logger = Logger()
  File "/Users/user/pytorch/gate-decorator-pruning/logger.py", line 42, in __init__
    json.dump(cfg, fp)
  File "/anaconda3/envs/deeplearning/lib/python3.7/json/__init__.py", line 179, in dump
    for chunk in iterable:
  File "/anaconda3/envs/deeplearning/lib/python3.7/json/encoder.py", line 438, in _iterencode
    o = _default(o)
  File "/anaconda3/envs/deeplearning/lib/python3.7/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type Config is not JSON serializable

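The cause: json cannot serialize the Config object, because it is a custom type (the repo uses its own dotdict here) rather than a plain dict.

Fix:

Change the json.dump() call in logger.py to: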

            with open(self.cfgfile, 'w') as fp:
                json.dump(cfg, fp, cls=dotdict)

This explicitly tells json.dump to use dotdict as the custom serializer class.
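A caveat on this fix: according to the Python documentation, the cls argument of json.dump is meant to be a json.JSONEncoder subclass. If cls=dotdict fails in your copy of the repo (a dotdict defined with __getattr__ = dict.get, for example, would end in "'NoneType' object is not callable"), a minimal sketch of an encoder that falls back to an object's attribute dict is below. Config here is a hypothetical stand-in for the repo's config object, and the keys and values are illustrative only.

```python
import json


class Config:
    """Hypothetical stand-in for the repo's Config object: settings stored as attributes."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


class ConfigEncoder(json.JSONEncoder):
    """Encoder that serializes otherwise unsupported objects through their __dict__."""

    def default(self, o):
        if hasattr(o, "__dict__"):
            return vars(o)  # nested Config objects are encoded recursively
        return super().default(o)  # anything else: the usual TypeError


cfg = Config(base=Config(cuda=False, multi_gpus=False), gbn=Config(T=10))
text = json.dumps(cfg, cls=ConfigEncoder)
print(text)
```

In logger.py the call would then become json.dump(cfg, fp, cls=ConfigEncoder).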

Next error:

AssertionError: Torch not compiled with CUDA enabled

Change cuda: true to false in prune.json.

Next error:

FileNotFoundError: [Errno 2] No such file or directory: './logs/vgg16_cifar10/ckp.160.torch'

This is because I skipped a step: the baseline has to be trained first with:

CUDA_VISIBLE_DEVICES=0 python main.py --config ./run/vgg16/baseline.json

This command produces the ckp.160.torch checkpoint.

 

Instead, I used the pretrained weights provided by PyTorch, changing this part of vgg16_prune_demo.py:

def get_pack():
    set_seeds()
    pack = recover_pack()

    model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
    pack.net.module.load_state_dict(model_dict)

to:

def get_pack():
    set_seeds()
    pack = recover_pack()
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16-397923af.pth'), strict=False)

Then inspect the resulting network structure:

pack, GBNs = get_pack()
for name, child in pack.net.named_children():
    print(name)
    print(child)

print(GBNs)

 

Running further fails with:

Traceback (most recent call last):
  File "./run/vgg16/vgg16_prune_demo.py", line 137, in <module>
    run()
  File "./run/vgg16/vgg16_prune_demo.py", line 112, in run
    pack, GBNs = get_pack()
  File "./run/vgg16/vgg16_prune_demo.py", line 29, in get_pack
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16-397923af.pth'), strict=False)
  File "/anaconda3/envs/deeplearning/lib/python3.7/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for VGG:
    size mismatch for features.7.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).

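This is because the network here is VGG16 with batch normalization. The plain vgg16 checkpoint has no BN layers, so the layer indices shift (features.7 is the 128-to-128 conv in vgg16 but the 64-to-128 conv in vgg16_bn) and the shapes no longer match. Switch to the vgg16_bn weights: https://download.pytorch.org/models/vgg16_bn-6c64b313.pth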
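Note that strict=False does not help with this error: it only tolerates missing or unexpected keys, while a shape mismatch on a key present in both is still reported. If you want to load only the tensors that fit, one option is to filter the checkpoint first. A sketch, with numpy arrays standing in for tensors; filter_compatible is a hypothetical helper, not part of the repo.

```python
import numpy as np


def filter_compatible(checkpoint, model_state):
    """Split a checkpoint into entries that fit the model and entries that don't.

    Both arguments map parameter names to objects with a .shape
    (torch tensors in practice, numpy arrays in this sketch).
    """
    kept, skipped = {}, []
    for name, value in checkpoint.items():
        target = model_state.get(name)
        if target is not None and tuple(target.shape) == tuple(value.shape):
            kept[name] = value
        else:
            skipped.append(name)
    return kept, skipped


# features.7 is the 128->128 conv in plain vgg16 but the 64->128 conv in vgg16_bn
checkpoint = {"features.0.weight": np.zeros((64, 3, 3, 3)),
              "features.7.weight": np.zeros((128, 128, 3, 3))}
model_state = {"features.0.weight": np.zeros((64, 3, 3, 3)),
               "features.7.weight": np.zeros((128, 64, 3, 3))}
kept, skipped = filter_compatible(checkpoint, model_state)
print(sorted(kept), skipped)
```

With real tensors, model_state would come from pack.net.state_dict() and the filtered dict would go to pack.net.load_state_dict(kept, strict=False). Using the matching vgg16_bn checkpoint, as above, is still the actual fix.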

 

Another error:

Traceback (most recent call last):
  File "./run/vgg16/vgg16_prune_demo.py", line 137, in <module>
    run()
  File "./run/vgg16/vgg16_prune_demo.py", line 114, in run
    cloned, _ = clone_model(pack.net)
  File "./run/vgg16/vgg16_prune_demo.py", line 54, in clone_model
    gbns = GatedBatchNorm2d.transform(model.module)
  File "/anaconda3/envs/deeplearning/lib/python3.7/site-packages/torch/nn/modules/module.py", line 591, in __getattr__
    type(self).__name__, name))
AttributeError: 'VGG' object has no attribute 'module'

Simply change model.module to model, because I did not wrap the model with DataParallel:

    if cfg.base.multi_gpus: # multi_gpus is set to False
        model = torch.nn.DataParallel(model)
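Rather than editing every .module by hand, a small helper can unwrap the model only when it is actually wrapped: nn.DataParallel keeps the wrapped network in its .module attribute, and a plain nn.Module raises AttributeError for it (as in the traceback above), so getattr with a default works for both. unwrap is a hypothetical helper name, and the two classes below are dummies standing in for torch modules so the sketch runs without torch.

```python
def unwrap(model):
    """Return the wrapped network if model has a .module (e.g. nn.DataParallel), else model."""
    return getattr(model, "module", model)


class Net:  # dummy stand-in for a plain nn.Module
    pass


class Wrapper:  # dummy stand-in for nn.DataParallel, which stores the network as .module
    def __init__(self, module):
        self.module = module


net = Net()
assert unwrap(net) is net
assert unwrap(Wrapper(net)) is net
```

With torch available, isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)) is a stricter test, since a plain model could in principle have a submodule named module. Usage would look like torch.save(unwrap(pack.net).state_dict(), path).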

 

 

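How it works, judging purely from the code

After reading all of the code, this is how I understand it, using the prune() function in vgg16_prune_demo.py as the example: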

prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)

1) Set up Tick-Tock

# First train on the full dataset for cfg.gbn.tock_epoch epochs
    prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

This is essentially fine-tuning the original model for cfg.gbn.tock_epoch epochs, with an extra L1 sparsity loss on the gates.
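The learning rate during a Tock is not constant: the iter_hook inside IterRecoverFramework.tock (prune/universal.py) ramps it linearly from lr_min up to lr_max over the first half of the tock_epoch epochs and back down over the second half. The same arithmetic pulled out as a standalone function (tock_lr is a hypothetical name; lr_min=0.001 and lr_max=0.01 are tock()'s defaults):

```python
def tock_lr(epoch, curr_iter, iters_per_epoch, tock_epochs, lr_min, lr_max):
    """Triangular LR: lr_min -> lr_max over the first half of the Tock, then back to lr_min."""
    half = tock_epochs * iters_per_epoch / 2
    step = epoch * iters_per_epoch + curr_iter
    if step < half:
        t = step / half
        lr = (1 - t) * lr_min + t * lr_max
    else:
        t = (step - half) / half
        lr = (1 - t) * lr_max + t * lr_min
    return max(lr, 0.0)


# 10 Tock epochs of 100 iterations each
for epoch in (0, 5, 9):
    print(epoch, tock_lr(epoch, 0, 100, 10, lr_min=0.001, lr_max=0.01))
```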

 

2) Then Ticks run in a loop:

def prune(pack, GBNs, BASE_FLOPS, BASE_PARAM):
    LOGS = []
    flops_save_points = set([30, 20, 10])
    iter_idx = 0

    pack.tick_trainset = pack.train_loader
    prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)
    # first train on the full dataset for cfg.gbn.tock_epoch epochs
    prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)
    while True:
        left_filter = prune_agent.total_filters - prune_agent.pruned_filters
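        num_to_prune = int(left_filter * cfg.gbn.p) # number of filters to prune this round; used to pick the threshold
        info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min) # one Tick, then score and prune
        flops, params = eval_prune(pack)
        info.update({ # record the result of this pruning step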
            'flops': '[%.2f%%] %.3f MFLOPS' % (flops/BASE_FLOPS * 100, flops / 1e6),
            'param': '[%.2f%%] %.3f M' % (params/BASE_PARAM * 100, params / 1e6)
        })
        LOGS.append(info)
        print('Iter: %d,\t FLOPS: %s,\t Param: %s,\t Left: %d,\t Pruned Ratio: %.2f %%,\t Train Loss: %.4f,\t Test Acc: %.2f' % 
            (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100, info['train_loss'], info['after_prune_test_acc']))
        
        iter_idx += 1
        if iter_idx % cfg.gbn.T == 0: # with T=10: after every 10 Ticks, run a Tock of tock_epoch (=10) epochs
            print('Tocking:')
            prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

        flops_ratio = flops/BASE_FLOPS * 100 # remaining FLOPs as a percentage of the original
        for point in [i for i in list(flops_save_points)]:
            if flops_ratio <= point: # e.g. once flops_ratio is below 30% (but still above 20%), save the current state and drop the 30 save point
                torch.save(pack.net.module.state_dict(), './logs/vgg16_cifar10/gbn_%s.ckp' % str(point))
                flops_save_points.remove(point)

        if len(flops_save_points) == 0: # no save points left: Tick-Tock is finished
            break

Each Tick computes importance scores and decides which channels of the BN layers to prune.
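The loop stops only through flops_save_points: a checkpoint is written the first time the remaining FLOPs fall to or below 30%, 20% and 10% of the original, and the loop ends once all three have been written. A pure-Python sketch of that bookkeeping (save_events is a hypothetical name and the FLOPs ratios are made-up input):

```python
def save_events(flops_ratios, points=(30, 20, 10)):
    """Return (iteration, save_point) pairs in the order checkpoints would be written."""
    remaining = set(points)
    events = []
    for it, ratio in enumerate(flops_ratios):
        for point in sorted(remaining, reverse=True):  # iterate over a copy, so removing is safe
            if ratio <= point:
                events.append((it, point))
                remaining.remove(point)
        if not remaining:  # all save points reached: Tick-Tock stops
            break
    return events


print(save_events([80.0, 50.0, 29.5, 25.0, 15.0, 9.0]))
print(save_events([35.0, 8.0]))  # one big drop writes several checkpoints at once
```

As the second call shows, if one Tick pushes FLOPs past several points, the same state is saved under several names; and if the ratio never reaches 10%, the loop never ends.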

 

3) Before Tick-Tock starts, every BN layer in the network is replaced with a GBN (GatedBatchNorm2d) layer:

def get_pack():
    set_seeds()
    pack = recover_pack()

    #model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'), strict=False)
    #pack.net.module.load_state_dict(model_dict)

    
    GBNs = GatedBatchNorm2d.transform(pack.net) # after this every BN layer is a GBN layer, and its BN weight is frozen
    for gbn in GBNs:
        gbn.extract_from_bn()
        
#     for name, child in pack.net.named_children():
#         print(name)
#         print(child)
        
    pack.optimizer = optim.SGD(
        pack.net.parameters() ,
        lr=2e-3,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov
    )

    return pack, GBNs

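GatedBatchNorm2d.transform(pack.net) wraps each BN layer in a GatedBatchNorm2d, which adds a gate parameter g. extract_from_bn() then rewrites the BN bias and weight and freezes the weight, so inside the layer only g (and the BN bias) is optimized: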

    def extract_from_bn(self):
        # freeze bn weight
        with torch.no_grad():
            self.bn.bias.set_(torch.clamp(self.bn.bias / self.bn.weight, -10, 10))
            self.g.set_(self.g * self.bn.weight.view(1, -1, 1, 1))
            self.bn.weight.set_(torch.ones_like(self.bn.weight))
            self.bn.weight.requires_grad = False

This corresponds to the gated BN layer described in the paper.
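Per channel c, what extract_from_bn does works out as follows (derived from the code above, not quoted from the paper). μ and σ² are the BN statistics, γ and β the original BN weight and bias, w and b the rewritten BN weight and bias, and m_c the bn_mask entry:

```latex
\begin{gathered}
\hat{x}_c = \frac{x_c - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}, \qquad
y_c^{\mathrm{BN}} = \gamma_c \hat{x}_c + \beta_c \\
w_c \leftarrow 1, \qquad
b_c \leftarrow \operatorname{clip}\!\left(\beta_c / \gamma_c,\, -10,\, 10\right), \qquad
g_c \leftarrow \gamma_c \\
y_c^{\mathrm{GBN}} = \left(w_c \hat{x}_c + b_c\right) g_c\, m_c = \gamma_c \hat{x}_c + \beta_c
\quad \text{when } m_c = 1 \text{ and } \lvert \beta_c / \gamma_c \rvert \le 10
\end{gathered}
```

So the conversion leaves the network output unchanged (unless the clip kicks in), and from then on g plays the role γ used to play. Calling extract_from_bn a second time changes nothing, because w is already 1.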

 

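A Tock on top of this simply trains the model on the whole training set, with the gates g added and the BN weights frozen. Its loss also carries the extra L1 term sparse_lambda * sum(|g|), the sparse_loss_hook built by get_gate_sparse_loss.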
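The extra term comes from get_gate_sparse_loss in prune/universal.py: an L1 penalty over the gates of all GBN layers. A numpy sketch of just the penalty and its subgradient, with the gates as plain arrays (gate_sparse_penalty and gate_sparse_grad are hypothetical names):

```python
import numpy as np


def gate_sparse_penalty(gates, sparse_lambda):
    """L1 penalty over all gate vectors, as added to the Tock loss."""
    return sparse_lambda * sum(float(np.abs(g).sum()) for g in gates)


def gate_sparse_grad(g, sparse_lambda):
    """Subgradient of the penalty w.r.t. one gate vector: constant pull of every gate toward 0."""
    return sparse_lambda * np.sign(g)


gates = [np.array([1.0, -2.0]), np.array([0.5])]
print(gate_sparse_penalty(gates, 1e-2))
print(gate_sparse_grad(gates[0], 1e-2))
```

Because the pull has the same magnitude for every gate, gates that the task loss does not actively support tend to drift toward zero, which in turn gives those channels low scores in the next Tick.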

 

4) Then the prune step:

info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min) # one Tick, then score and prune

This is a Tick followed by the actual pruning.

First, the Tick:

    def tick(self, lr, test):
        ''' Do Prune '''
        self.freeze_conv()
        info = self.recover(lr, test)
        self.restore_conv()
        return info

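It freezes the convolution parameters, so during a Tick only the GBN layers (the gates g, plus the BN biases, which are not frozen) and the fully connected layer are trained.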
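While the Tick runs, recover() (prune/universal.py) registers cal_score as a hook on every g, so each backward pass adds |dL/dg * g| to score; get_score later subtracts an optional FLOPs term and multiplies by bn_mask. A numpy sketch of both steps, simplified to a fixed g (in the real Tick, g is also updated between batches); the function names and numbers are hypothetical:

```python
import numpy as np


def accumulate_score(g, grads_per_batch):
    """Sum |dL/dg * g| over batches, as GatedBatchNorm2d.cal_score does per backward pass."""
    score = np.zeros_like(g)
    for grad in grads_per_batch:
        score += np.abs(grad * g)
    return score


def masked_score(score, bn_mask, area, eta=0.0):
    """As in get_score: subtract eta * area * (live channels), then zero out pruned channels."""
    flops_reg = eta * area * bn_mask.sum()
    return (score - flops_reg) * bn_mask


g = np.array([1.0, 0.5, 2.0])
grads = [np.array([0.1, -0.2, 0.0]), np.array([0.3, 0.2, -0.1])]
score = accumulate_score(g, grads)
mask = np.array([1.0, 1.0, 0.0])  # channel 2 already pruned
print(masked_score(score, mask, area=64))
print(masked_score(score, mask, area=64, eta=0.001))
```

With eta > 0, layers with larger feature maps (area = H*W) and more live channels get their scores lowered more, which appears to be how flops_eta steers pruning toward FLOP-heavy layers.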

Next comes the pruning itself:

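The gates g trained during the Tick are used to score every filter of each BN layer. bn_mask (see the GatedBatchNorm2d class in prune/universal.py) starts as all ones, meaning every filter is kept, so self.score * self.bn_mask gives the score of every live filter. The scores are sorted to compute a threshold, the threshold yields a 0/1 mask, and set_mask() applies it with g.bn_mask.set_(mask * g.bn_mask). In each GBN layer, a 0 in bn_mask therefore marks a pruned filter and a 1 a kept one.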
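The threshold and mask logic lives in IterRecoverFramework.get_threshold and set_mask. Scores from all layers are pooled into one list, which is what makes the pruning global. A numpy sketch of the same steps for ungrouped layers (count = 1); global_threshold and update_mask are hypothetical names, and the scores are made up:

```python
import numpy as np


def global_threshold(scores, minimals, num_to_prune):
    """Pool candidate scores from all layers and return the value at index num_to_prune."""
    pooled = []
    for s, minimal in zip(scores, minimals):
        candidates = np.sort(s)[:-minimal]          # the top `minimal` filters are protected
        pooled.append(candidates[candidates != 0])  # score 0 means already pruned
    return float(np.sort(np.concatenate(pooled))[num_to_prune])


def update_mask(score, minimal, threshold, old_mask):
    """Keep the top `minimal` filters plus every filter scoring above the global threshold."""
    hard_threshold = np.sort(score)[-minimal]
    keep = (score >= hard_threshold) | (score > threshold)
    return keep.astype(np.float32) * old_mask


scores = [np.array([0.5, 0.1, 0.9, 0.3]), np.array([0.2, 0.05, 0.7])]
minimals = [1, 1]
t = global_threshold(scores, minimals, num_to_prune=2)
masks = [update_mask(s, m, t, np.ones_like(s)) for s, m in zip(scores, minimals)]
print(t, masks)
```

As written, filters scoring exactly the threshold are pruned too, so with distinct scores a round removes num_to_prune + 1 filters (three here, for num_to_prune = 2), while each layer's top minimal filters are protected in that round.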

So pruning works entirely through bn_mask, because GatedBatchNorm2d.forward contains:

    def forward(self, x): 
        x = self.bn(x) * self.g

        self.area[0] = x.shape[-1] * x.shape[-2]

        if self.bn_mask is not None:
            return x * self.bn_mask
        return x

So during training the GBN output is x * self.bn_mask: every channel whose bn_mask entry is 0 becomes all zeros, which is equivalent to removing that filter.

 

5) Next, the channel counts of the convolution and fully connected layers are made to match the pruned GBN layers:

    _ = Conv2dObserver.transform(pack.net.module)
    pack.net.module.classifier = FinalLinearObserver(pack.net.module.classifier)

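This wraps them in Conv2dObserver and FinalLinearObserver respectively.

Conv2dObserver holds two buffers, in_mask and out_mask. In the forward pass it accumulates the per-channel sum of |input|, and in the backward pass the per-channel sum of |gradient| at its output; a channel whose sum is still 0 at the end has been pruned: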

    def _forward_hook(self, m, _in, _out):
        x = _in[0]
        self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)

    def _backward_hook(self, grad):
        self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
        new_grad = torch.ones_like(grad)
        return new_grad

    def forward(self, x):
        output = self.conv(x)
        noise = torch.zeros_like(output).normal_()
        output = output + noise
        if self.training:
            output.register_hook(self._backward_hook)
        return output

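FinalLinearObserver follows the same idea but only tracks in_mask, since the output classes of the last layer are never pruned.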
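The forward hook's reduction is simply "sum of absolute values over batch, height and width", one number per channel, so a channel that is zero everywhere, as a masked GBN output is, sums to exactly 0. (The Gaussian noise added to the conv output presumably keeps a live channel from summing to 0 by accident.) A numpy sketch; channel_activity is a hypothetical name and the layout is PyTorch's NCHW:

```python
import numpy as np


def channel_activity(x):
    """Per-channel sum of |x| over batch (0), height (2) and width (3) of an NCHW array."""
    return np.abs(x).sum(axis=(0, 2, 3))


x = np.ones((2, 4, 5, 5))  # batch of 2, 4 channels, 5x5 maps
x[:, 1] = 0.0              # channel 1 zeroed, as by bn_mask == 0
in_mask = channel_activity(x)
print(in_mask, in_mask != 0)
```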

 

6) Then observe and melt_all:

    Meltable.observe(pack, 0.001)
    Meltable.melt_all(pack.net)

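As far as I can tell, observe makes the weight of every nn.BatchNorm2d module (including the BNs wrapped inside GBNs) at least 1e-3 (abs() then +1e-3), replaces ReLU with LeakyReLU, puts the BN layers in eval mode, and runs OBSERVE_TIMES = 5 training iterations with update=False; afterwards ReLU is restored and the 1e-3 is subtracted again. (For a long time I could not work out the purpose of this.)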

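Then it clicked: this pass exists only to fill in in_mask and out_mask of every Conv2dObserver and FinalLinearObserver, which melt_all then uses. The tweaks presumably ensure that a zero in those masks can only come from a GBN's bn_mask: ReLU would zero the gradient of negative activations, and a zero BN weight would block it too.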
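Why the temporary LeakyReLU matters (my reading of the code, not something the repo states): out_mask is built from the gradient reaching each conv output, and a channel counts as pruned when that gradient is 0. bn_mask zeroes it for pruned channels, as intended, but ReLU would also zero it for a live channel whose activations are all negative. A numpy sketch for the path gate(mask) -> activation, assuming an all-ones gradient from above and leaving out BN and gate scaling; out_grad_activity is hypothetical, and 0.01 is nn.LeakyReLU's default negative slope:

```python
import numpy as np


def out_grad_activity(z, mask, act="relu", slope=0.01):
    """Per-channel |d act(z * mask) / dz| summed over N, H, W, for an all-ones upstream gradient.

    z: NCHW values entering the gate; mask: per-channel 0/1 bn_mask.
    """
    m = mask.reshape(1, -1, 1, 1)
    y = z * m
    if act == "relu":
        d_act = (y > 0).astype(float)
    else:  # LeakyReLU
        d_act = np.where(y > 0, 1.0, slope)
    grad = d_act * m  # chain rule through the mask
    return np.abs(grad).sum(axis=(0, 2, 3))


z = np.ones((1, 3, 2, 2))
z[:, 0] = -1.0                      # live channel with only negative activations
mask = np.array([1.0, 1.0, 0.0])    # channel 2 pruned by the GBN
print(out_grad_activity(z, mask, "relu"))   # channel 0 wrongly looks pruned
print(out_grad_activity(z, mask, "leaky"))  # only channel 2 is 0
```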

 

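melt_all then rebuilds the network: through its melt() method, every GBN, Conv2dObserver and FinalLinearObserver is replaced by a plain, smaller BatchNorm2d, Conv2d or Linear that keeps only the surviving channels (according to the GBN's bn_mask and the observers' in_mask/out_mask), with the corresponding parameters copied over.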
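GatedBatchNorm2d.melt folds the gate into a plain BatchNorm2d (new weight = weight * g, new bias = bias * g) and keeps only channels with a non-zero bn_mask, while Conv2dObserver.melt slices the conv weight with weight[out_mask != 0][:, in_mask != 0]. A numpy check, with random made-up statistics, that the melted BN reproduces the masked GBN output on the kept channels (eval-mode BN; the function names are hypothetical):

```python
import numpy as np


def bn_eval(x, mean, var, weight, bias, eps=1e-5):
    """Eval-mode BatchNorm2d on an NCHW array."""
    c = (1, -1, 1, 1)
    x_hat = (x - mean.reshape(c)) / np.sqrt(var.reshape(c) + eps)
    return x_hat * weight.reshape(c) + bias.reshape(c)


def gbn_eval(x, mean, var, bias, g, mask):
    """GBN after extract_from_bn: BN weight fixed at 1, output scaled by g and bn_mask."""
    c = (1, -1, 1, 1)
    return bn_eval(x, mean, var, np.ones_like(g), bias) * g.reshape(c) * mask.reshape(c)


def melt_bn(mean, var, bias, g, mask):
    """Smaller plain BN: bn.weight is 1, so new weight = g and new bias = bias * g."""
    keep = mask != 0
    return mean[keep], var[keep], g[keep], (bias * g)[keep]


def melt_conv_weight(weight, in_mask, out_mask):
    """Keep surviving output filters, then surviving input channels (groups == 1 case)."""
    return weight[out_mask != 0][:, in_mask != 0]


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4, 4))
mean, var = rng.normal(size=3), rng.uniform(0.5, 2.0, size=3)
bias, g = rng.normal(size=3), rng.normal(size=3)
mask = np.array([1.0, 0.0, 1.0])

full = gbn_eval(x, mean, var, bias, g, mask)
small = bn_eval(x[:, mask != 0], *melt_bn(mean, var, bias, g, mask))
print(np.allclose(full[:, mask != 0], small), np.all(full[:, 1] == 0))

w = rng.normal(size=(4, 5, 3, 3))
print(melt_conv_weight(w, np.array([1, 0, 1, 1, 0]), np.array([0, 2, 0, 1])).shape)
```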

 

7) Finally, fine-tune the new, smaller network:

    _ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)

 

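You need to save the fine-tuned model yourself:

(1) Save only the weights (state_dict):

torch.save(pack.net.module.state_dict(), os.path.join(saving_path, '30_finetune_state.pth'))

.module is needed because the model was wrapped with:

model = torch.nn.DataParallel(model)

If you did not wrap it, drop .module. To load this state_dict later you first have to rebuild the pruned architecture, because its channel counts no longer match the VGG16 built by get_model().

(2) Save the whole model, structure included: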

torch.save(pack.net.module, os.path.join(saving_path, '30_finetune.pth'))

 

The complete script:

import os
import sys

_r = os.getcwd().split('/')
_p = '/'.join(_r[:_r.index('gate-decorator-pruning')+1])
print('Change dir from %s to %s' % (os.getcwd(), _p))
os.chdir(_p)
sys.path.append(_p)

import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim

from config import cfg
from logger import logger
from main import set_seeds, recover_pack, adjust_learning_rate, _step_lr, _sgdr
from models import get_model
from utils import dotdict

from prune.universal import Meltable, GatedBatchNorm2d, Conv2dObserver, IterRecoverFramework, FinalLinearObserver
from prune.utils import analyse_model, finetune

def get_pack():
    set_seeds()
    pack = recover_pack()

    #model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'), strict=False)
    #pack.net.module.load_state_dict(model_dict)

    
    GBNs = GatedBatchNorm2d.transform(pack.net) # after this every BN layer is a GBN layer, and its BN weight is frozen
    for gbn in GBNs:
        gbn.extract_from_bn()
        
#     for name, child in pack.net.named_children():
#         print(name)
#         print(child)
        
    pack.optimizer = optim.SGD(
        pack.net.parameters() ,
        lr=2e-3,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov
    )

    return pack, GBNs
# get_pack()

def clone_model(net):
    model = get_model()
    gbns = GatedBatchNorm2d.transform(model)
    model.load_state_dict(net.state_dict())
    return model, gbns


def eval_prune(pack):
    cloned, _ = clone_model(pack.net)
    _ = Conv2dObserver.transform(cloned.module) # wrap conv2d layers so they can shrink to match the pruned BNs
    cloned.module.classifier = FinalLinearObserver(cloned.module.classifier) # likewise for the final linear layer
    cloned_pack = dotdict(pack.copy())
    cloned_pack.net = cloned
    Meltable.observe(cloned_pack, 0.001)
    Meltable.melt_all(cloned_pack.net) # rebuild smaller layers from the current masks
#     flops, params = analyse_model(cloned_pack.net.module, torch.randn(1, 3, 32, 32).cuda())
    flops, params = analyse_model(cloned_pack.net.module, torch.randn(1, 3, 32, 32))
    del cloned
    del cloned_pack
    
    return flops, params


def prune(pack, GBNs, BASE_FLOPS, BASE_PARAM):
    LOGS = []
    flops_save_points = set([30, 20, 10])
    iter_idx = 0

    pack.tick_trainset = pack.train_loader
    prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)
    # first train on the full dataset for cfg.gbn.tock_epoch epochs
    prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)
    while True:
        left_filter = prune_agent.total_filters - prune_agent.pruned_filters
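        num_to_prune = int(left_filter * cfg.gbn.p) # number of filters to prune this round; used to pick the threshold
        info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min) # one Tick, then score and prune
        flops, params = eval_prune(pack)
        info.update({ # record the result of this pruning step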
            'flops': '[%.2f%%] %.3f MFLOPS' % (flops/BASE_FLOPS * 100, flops / 1e6),
            'param': '[%.2f%%] %.3f M' % (params/BASE_PARAM * 100, params / 1e6)
        })
        LOGS.append(info)
        print('Iter: %d,\t FLOPS: %s,\t Param: %s,\t Left: %d,\t Pruned Ratio: %.2f %%,\t Train Loss: %.4f,\t Test Acc: %.2f' % 
            (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100, info['train_loss'], info['after_prune_test_acc']))
        
        iter_idx += 1
        if iter_idx % cfg.gbn.T == 0: # with T=10: after every 10 Ticks, run a Tock of tock_epoch (=10) epochs
            print('Tocking:')
            prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

        flops_ratio = flops/BASE_FLOPS * 100 # remaining FLOPs as a percentage of the original
        for point in [i for i in list(flops_save_points)]:
            if flops_ratio <= point: # e.g. once flops_ratio is below 30% (but still above 20%), save the current state and drop the 30 save point
                torch.save(pack.net.module.state_dict(), './logs/vgg16_cifar10/gbn_%s.ckp' % str(point))
                flops_save_points.remove(point)

        if len(flops_save_points) == 0: # no save points left: Tick-Tock is finished
            break


def run():
    pack, GBNs = get_pack()

    cloned, _ = clone_model(pack.net)
#     BASE_FLOPS, BASE_PARAM = analyse_model(cloned, torch.randn(1, 3, 32, 32).cuda()) # FLOPs and parameter count of the pretrained model before pruning
    BASE_FLOPS, BASE_PARAM = analyse_model(cloned, torch.randn(1, 3, 32, 32))
    print('%.3f MFLOPS' % (BASE_FLOPS / 1e6))
    print('%.3f M' % (BASE_PARAM / 1e6))
    del cloned

    prune(pack, GBNs, BASE_FLOPS, BASE_PARAM) # run Tick-Tock

    _ = Conv2dObserver.transform(pack.net.module)
    pack.net.module.classifier = FinalLinearObserver(pack.net.module.classifier)
    Meltable.observe(pack, 0.001)
    Meltable.melt_all(pack.net)

    pack.optimizer = optim.SGD(
        pack.net.parameters(),
        lr=1,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov
    )

    _ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)

run()

 

The most important code is in prune/universal.py:

It contains the classes that wrap the BN, conv and final linear layers, as well as the class that drives Tick-Tock:

"""
 * Copyright (C) 2019 Zhonghui You
 * If you are using this code in your research, please cite the paper:
 * Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks, in NeurIPS 2019.
"""

import torch
import torch.nn as nn

import numpy as np
import uuid

OBSERVE_TIMES = 5
FINISH_SIGNAL = 'finish'

class Meltable(nn.Module):
    def __init__(self):
        super(Meltable, self).__init__()

    @classmethod
    def melt_all(cls, net):
        def _melt(modules):
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    _melt(modules[k]._modules)
                if isinstance(modules[k], Meltable):
                    modules[k] = modules[k].melt() # replace the wrapper with a smaller plain layer

        _melt(net._modules)

    @classmethod
    def observe(cls, pack, lr):
        tmp = pack.train_loader
        if pack.tick_trainset is not None:
            pack.train_loader = pack.tick_trainset

        for m in pack.net.modules():
            if isinstance(m, nn.BatchNorm2d): # make every BN weight >= 1e-3, presumably so no BN blocks the gradients the observers record
                m.weight.data.abs_().add_(1e-3)

        def replace_relu(modules):
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    replace_relu(modules[k]._modules)
                if isinstance(modules[k], nn.ReLU):
                    modules[k] = nn.LeakyReLU(inplace=True)
        replace_relu(pack.net._modules)

        count = 0
        def _freeze_bn(curr_iter, total_iter):
            for m in pack.net.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
            nonlocal count
            count += 1
            if count == OBSERVE_TIMES:
                return FINISH_SIGNAL
        info = pack.trainer.train(pack, iter_hook=_freeze_bn, update=False, mute=True)

        def recover_relu(modules):
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    recover_relu(modules[k]._modules)
                if isinstance(modules[k], nn.LeakyReLU):
                    modules[k] = nn.ReLU(inplace=True)
        recover_relu(pack.net._modules)

        for m in pack.net.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.abs_().add_(-1e-3)

        pack.train_loader = tmp


class GatedBatchNorm2d(Meltable):
    def __init__(self, bn, minimal_ratio = 0.1):
        super(GatedBatchNorm2d, self).__init__()
        assert isinstance(bn, nn.BatchNorm2d)
        self.bn = bn
        self.group_id = uuid.uuid1()

        self.channel_size = bn.weight.shape[0]
        self.minimal_filter = max(1, int(self.channel_size * minimal_ratio))
        self.device = bn.weight.device
        self._hook = None

        self.g = nn.Parameter(torch.ones(1, self.channel_size, 1, 1).to(self.device), requires_grad=True)
        self.register_buffer('area', torch.zeros(1).to(self.device)) # spatial size (H*W) of the current feature map
        self.register_buffer('score', torch.zeros(1, self.channel_size, 1, 1).to(self.device))
        # 0 or 1 per channel; set_mask() updates it via g.bn_mask.set_(mask * g.bn_mask). This is what decides whether a channel is kept
        self.register_buffer('bn_mask', torch.ones(1, self.channel_size, 1, 1).to(self.device)) 
        
        self.extract_from_bn()

    def set_groupid(self, new_id):
        self.group_id = new_id

    def extra_repr(self):
        return '%d -> %d | ID: %s' % (self.channel_size, int(self.bn_mask.sum()), self.group_id)

    def extract_from_bn(self):
        # freeze bn weight
        with torch.no_grad():
            self.bn.bias.set_(torch.clamp(self.bn.bias / self.bn.weight, -10, 10))
            self.g.set_(self.g * self.bn.weight.view(1, -1, 1, 1))
            self.bn.weight.set_(torch.ones_like(self.bn.weight))
            self.bn.weight.requires_grad = False

    def reset_score(self):
        self.score.zero_()

    def cal_score(self, grad):
        # used for hook
        self.score += (grad * self.g).abs()

    def start_collecting_scores(self):
        if self._hook is not None:
            self._hook.remove()

        self._hook = self.g.register_hook(self.cal_score)

    def stop_collecting_scores(self):
        if self._hook is not None:
            self._hook.remove()
            self._hook = None
    
    def get_score(self, eta=0.0): # eta weights a FLOPs penalty (eta * H*W * live channels); with eta = 0 this is just self.score * self.bn_mask
        # use self.bn_mask.sum() to calculate the number of input channel. eta should had been normed
        flops_reg = eta * int(self.area[0]) * self.bn_mask.sum()
        return ((self.score - flops_reg) * self.bn_mask).view(-1)

    def forward(self, x): # called on every forward pass; the BN output is multiplied by the gate self.g
        x = self.bn(x) * self.g

        self.area[0] = x.shape[-1] * x.shape[-2]

        if self.bn_mask is not None:
            return x * self.bn_mask # zero out channels whose bn_mask entry is 0
        return x

    def melt(self): # after pruning: build a smaller plain BatchNorm2d with g folded in
        with torch.no_grad():
            mask = self.bn_mask.view(-1)
            replacer = nn.BatchNorm2d(int(self.bn_mask.sum())).to(self.bn.weight.device)
            replacer.running_var.set_(self.bn.running_var[mask != 0])
            replacer.running_mean.set_(self.bn.running_mean[mask != 0])
            replacer.weight.set_((self.bn.weight * self.g.view(-1))[mask != 0])
            replacer.bias.set_((self.bn.bias * self.g.view(-1))[mask != 0])
        return replacer

    @classmethod
    def transform(cls, net, minimal_ratio=0.1):
        r = []
        def _inject(modules):
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    _inject(modules[k]._modules)
                if isinstance(modules[k], nn.BatchNorm2d):
                    modules[k] = GatedBatchNorm2d(modules[k], minimal_ratio) # replace the original BN layer with a GBN layer
                    r.append(modules[k])
        _inject(net._modules)
        return r


class FinalLinearObserver(Meltable):
    ''' assert was in the last layer. only input was masked '''
    def __init__(self, linear):
        super(FinalLinearObserver, self).__init__()
        assert isinstance(linear, nn.Linear)
        self.linear = linear
        self.in_mask = torch.zeros(linear.weight.shape[1]).to('cpu')
        self.f_hook = linear.register_forward_hook(self._forward_hook)
    
    def extra_repr(self):
        return '(%d, %d) -> (%d, %d)' % (
            int(self.linear.weight.shape[1]),
            int(self.linear.weight.shape[0]),
            int((self.in_mask != 0).sum()),
            int(self.linear.weight.shape[0]))

    def _forward_hook(self, m, _in, _out):
        x = _in[0]
        self.in_mask += x.data.abs().cpu().sum(0, keepdim=True).view(-1)

    def forward(self, x):
        return self.linear(x)

    def melt(self):
        with torch.no_grad():
            replacer = nn.Linear(int((self.in_mask != 0).sum()), self.linear.weight.shape[0]).to(self.linear.weight.device)
            replacer.weight.set_(self.linear.weight[:, self.in_mask != 0])
            replacer.bias.set_(self.linear.bias)
        return replacer


class Conv2dObserver(Meltable):
    def __init__(self, conv):
        super(Conv2dObserver, self).__init__()
        assert isinstance(conv, nn.Conv2d)
        self.conv = conv
        self.in_mask = torch.zeros(conv.in_channels).to('cpu')
        self.out_mask = torch.zeros(conv.out_channels).to('cpu')
        self.f_hook = conv.register_forward_hook(self._forward_hook)

    def extra_repr(self):
        return '(%d, %d) -> (%d, %d)' % (self.conv.in_channels, self.conv.out_channels, int((self.in_mask != 0).sum()), int((self.out_mask != 0).sum()))
    
    def _forward_hook(self, m, _in, _out):
        x = _in[0]
        self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)

    def _backward_hook(self, grad):
        self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
        new_grad = torch.ones_like(grad)
        return new_grad

    def forward(self, x):
        output = self.conv(x)
        noise = torch.zeros_like(output).normal_()
        output = output + noise
        if self.training:
            output.register_hook(self._backward_hook)
        return output

    def melt(self):
        if self.conv.groups == 1:
            groups = 1
        elif self.conv.groups == self.conv.out_channels:
            groups = int((self.out_mask != 0).sum())
        else:
            assert False

        replacer = nn.Conv2d(
            in_channels = int((self.in_mask != 0).sum()),
            out_channels = int((self.out_mask != 0).sum()),
            kernel_size = self.conv.kernel_size,
            stride = self.conv.stride,
            padding = self.conv.padding,
            dilation = self.conv.dilation,
            groups = groups,
            bias = (self.conv.bias is not None)
        ).to(self.conv.weight.device)

        with torch.no_grad():
            if self.conv.groups == 1:
                replacer.weight.set_(self.conv.weight[self.out_mask != 0][:, self.in_mask != 0])
            else:
                replacer.weight.set_(self.conv.weight[self.out_mask != 0])
            if self.conv.bias is not None:
                replacer.bias.set_(self.conv.bias[self.out_mask != 0])
        return replacer
    
    @classmethod
    def transform(cls, net):
        r = []
        def _inject(modules):
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    _inject(modules[k]._modules)
                if isinstance(modules[k], nn.Conv2d):
                    modules[k] = Conv2dObserver(modules[k])
                    r.append(modules[k])
        _inject(net._modules)
        return r

# -------------------------------------------------------------------------------------------------------

def get_gate_sparse_loss(masks, sparse_lambda):
    def _loss_hook(data, label, logits):
        loss = 0.0
        for gbn in masks:
            if isinstance(gbn, GatedBatchNorm2d):
                loss += gbn.g.abs().sum()
        return sparse_lambda * loss

    return _loss_hook

class IterRecoverFramework():
    def __init__(self, pack, masks, sparse_lambda=1e-5, flops_eta=0.0, minium_filter=10):
        self.pack = pack
        self.masks = masks
        self.sparse_loss_hook = get_gate_sparse_loss(masks, sparse_lambda) # extra L1 term of the Tock loss
        self.logs = []
        # minium_filter would be delete
        # self.minium_filter = minium_filter
        self.sparse_lambda = sparse_lambda
        self.flops_eta = flops_eta
        self.eta_scale_factor = 1.0

        self.total_filters = sum([m.bn.weight.shape[0] for m in masks])
        self.pruned_filters = 0

    def recover(self, lr, test):
        for gbn in self.masks:
            if isinstance(gbn, GatedBatchNorm2d):
                gbn.reset_score()
                gbn.start_collecting_scores()

        for g in self.pack.optimizer.param_groups:
            g['lr'] = lr

        tmp = self.pack.train_loader
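        self.pack.train_loader = self.pack.tick_trainset # use the Tick training set (may be a subset of the training data)
        info = self.pack.trainer.train(self.pack) # run the Tick: conv is frozen, so only the gates φ (plus BN biases) and the last linear layer are updated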
        self.pack.train_loader = tmp

        if test:
            info.update(self.pack.trainer.test(self.pack))

        info.update({'LR': lr})

        for gbn in self.masks:
            if isinstance(gbn, GatedBatchNorm2d):
                gbn.stop_collecting_scores()
        
        return info

    def get_threshold(self, status, num):
        '''
            input score list from layers, and the number of filter to prune
        '''
        total_filters, left_filters = 0, 0
        filtered_score_list = []
        
        for group_id, v in status.items():
            total_filters += len(v['score']) * v['count'] # count > 1 means several GBNs share this group, and their channels share one score
            left_filters += int((v['score'] != 0).sum()) * v['count']

            sorted_score = np.sort(v['score'])[:-v['minimal']] # drop the top v['minimal'] scores from the candidates, so those filters are not pruned here
            filtered_score = sorted_score[sorted_score != 0]
            for i in range(v['count']):
                filtered_score_list.append(filtered_score) # channels in one group share a score, so append it count times

        scores = np.concatenate(filtered_score_list) # pool the channel scores of all GBNs
        threshold = np.sort(scores)[num] # sort and take the value at index num as the threshold
        to_prune = int((scores <= threshold).sum()) # channels scoring <= threshold are pruned

        info = {'left': left_filters, 'to_prune': to_prune, 'total_pruned_ratio': (total_filters - left_filters + to_prune) / total_filters}
        return threshold, info

    # compute a mask from the scores and the threshold and use it to update g.bn_mask, which decides whether each channel is kept,
    # because GBN.forward returns x * self.bn_mask
    def set_mask(self, status, threshold): # self.masks here is the list of GBNs
        for group_id, v in status.items():
            hard_threshold = float(np.sort(v['score'])[-v['minimal']]) # score of the v['minimal']-th best filter; v['minimal'] = max(1, int(channel_size * minimal_ratio))
            hard_mask = v['score'] >= hard_threshold # always keep the top v['minimal'] filters
            soft_mask = v['score'] > threshold # keep filters above the global threshold
            v['mask'] = (hard_mask + soft_mask)

        with torch.no_grad():
            for g in self.masks:
                if g.group_id in status:
                    mask = torch.from_numpy(status[g.group_id]['mask'].astype('float32')).to(g.device).view(1, -1, 1, 1)
                    g.bn_mask.set_(mask * g.bn_mask)

    def freeze_conv(self): # conv parameters are not updated during a Tick, so freeze them
        self._status = {}
        for m in self.pack.net.modules():
            if isinstance(m, nn.Conv2d):
                for p in m.parameters():
                    self._status[id(p)] = p.requires_grad
                    p.requires_grad = False

    def restore_conv(self):
        for m in self.pack.net.modules():
            if isinstance(m, nn.Conv2d):
                for p in m.parameters():
                    p.requires_grad = self._status[id(p)]

    def tock(self, lr_min=0.001, lr_max=0.01, tock_epoch = 20, mute=False, acc_step=1): # Tock: the loss includes the extra sparse loss (loss_hook); all trainable parameters are updated, but no scores are collected (that only happens in a Tick)
        logs = []
        epoch = 0
        T = tock_epoch
        def iter_hook(curr_iter, total_iter):
            total = T * total_iter
            half = total / 2
            itered = epoch * total_iter + curr_iter
            if itered < half:
                _iter = epoch * total_iter + curr_iter
                _lr = (1- _iter / half) * lr_min + (_iter / half) * lr_max
            else:
                _iter = (epoch - T/2) * total_iter + curr_iter
                _lr = (1- _iter / half) * lr_max + (_iter / half) * lr_min
            
            for g in self.pack.optimizer.param_groups:
                g['lr'] = max(_lr, 0)
                # g['momentum'] = 0.9
        
        for i in range(T): # T epochs
            info = self.pack.trainer.train(self.pack, loss_hook = self.sparse_loss_hook, iter_hook = iter_hook, acc_step=acc_step)
            info.update(self.pack.trainer.test(self.pack))
            info.update({'LR': self.pack.optimizer.param_groups[0]['lr']})
            epoch += 1
            if not mute:
                print('Tock - %d,\t Test Loss: %.4f,\t Test Acc: %.2f, Final LR: %.5f' % (i, info['test_loss'], info['acc@1'], info['LR']))
                # variant for an age/gender model (needs 'age_correct' / 'gender_correct' from trainer.test):
                # print('Tock - %d,\t Test Loss: %.4f,\t Test age_correct Acc: %.2f, Test gender_correct Acc: %.2f, Final LR: %.5f' % (i, info['test_loss'], info['age_correct'], info['gender_correct'], info['LR']))
            logs.append(info)
        return logs

    def tick(self, lr, test):
        ''' Do Prune '''
        self.freeze_conv()
        info = self.recover(lr, test)
        self.restore_conv()
        return info

    def prune(self, num, tick=False, lr=0.01, test=True):
        info = {}
        if tick:
            info = self.tick(lr, test)

            area = []
            for g in self.masks:
                area.append(int(g.area[0]))
            self.eta_scale_factor = min(area)

        status = {}
        for g in self.masks:
            if g.group_id in status:
                # assert the gbn in same group has the same channel size
                status[g.group_id]['score'] += g.get_score(self.flops_eta / self.eta_scale_factor).cpu().data.numpy()
                status[g.group_id]['count'] += 1
            else:
                status[g.group_id] = {
                    'score': g.get_score(self.flops_eta / self.eta_scale_factor).cpu().data.numpy(),
                    'minimal': g.minimal_filter,
                    'count': 1,
                    'mask': None
                }
        
        threshold, r = self.get_threshold(status, num)
        info.update(r)
        threshold = float(threshold)
        self.set_mask(status, threshold)
        if test:
            info.update({'after_prune_test_age_acc': self.pack.trainer.test(self.pack)['age_correct']})
            info.update({'after_prune_test_gender_acc': self.pack.trainer.test(self.pack)['gender_correct']})
        self.logs.append(info)
        self.pruned_filters = self.total_filters - info['left']
        info['total'] = self.total_filters
        return info
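The iter_hook inside tock implements a triangular learning-rate schedule. Over the T tock epochs the LR climbs linearly from lr_min to lr_max during the first half and then falls back to lr_min during the second half. Below is a minimal standalone restatement of that formula so it can be checked in isolation. The function name tock_lr and its argument order are my own and are not part of the repo.

```python
def tock_lr(epoch, curr_iter, total_iter, T, lr_min, lr_max):
    """LR at step (epoch, curr_iter) of a T-epoch tock phase.

    Restates the iter_hook formula: linear warm-up from lr_min to lr_max over
    the first half of the phase, then linear decay back to lr_min.
    """
    total = T * total_iter
    half = total / 2
    itered = epoch * total_iter + curr_iter  # global step inside the tock phase
    if itered < half:
        t = itered / half
        lr = (1 - t) * lr_min + t * lr_max
    else:
        # same as ((epoch - T/2) * total_iter + curr_iter) / half in the original
        t = (itered - half) / half
        lr = (1 - t) * lr_max + t * lr_min
    return max(lr, 0)


# T=10 epochs of 100 iterations: starts at lr_min, peaks at epoch 5, ends near lr_min
for step in [(0, 0), (2, 50), (5, 0), (9, 99)]:
    print(step, tock_lr(*step, total_iter=100, T=10, lr_min=0.001, lr_max=0.01))
```

Putting the peak in the middle means each tock starts gently from the pruned state, explores with a larger LR, and ends at a small LR before the next round of pruning.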

 

 

Resnet

The main difference between ResNet and VGG seems to be that ResNet has shortcut branches, so its BN layers have to be grouped:

1) resnet-56/resnet56_prune.ipynb

Run the notebook cell by cell:

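GBNs = GatedBatchNorm2d.transform(pack.net) # GatedBatchNorm2d's __init__ already calls self.extract_from_bn(), which handles the outermost bn layer at this first step
print(GBNs) # extra_repr(self) defines the extra info in the printout; at this point every ID is different
for gbn in GBNs: # then initialize the layers one by one
    gbn.extract_from_bn() # the bn weights stay frozen during training; only g is updated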

Output:

[GatedBatchNorm2d(
  16 -> 16 | ID: 94d38830-2d48-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  16 -> 16 | ID: 94d5b330-2d48-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  16 -> 16 | ID: 94d5bb32-2d48-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), 
...
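To see why the BN weights can stay frozen while only g is trained, here is a per-channel sketch in plain Python. It assumes that the gated BN computes bn(x) * g * bn_mask and that extract_from_bn moves the BN scale into the gate (g becomes g*weight, weight becomes 1, bias becomes bias/weight). That is my reading of the code rather than a verified copy of it, and the helper names are hypothetical.

```python
import math


def bn_channel(x, mean, var, weight, bias, eps=1e-5):
    """Inference-mode BatchNorm for one value of one channel."""
    x_hat = (x - mean) / math.sqrt(var + eps)
    return weight * x_hat + bias


def fold_into_gate(weight, bias, g=1.0):
    """Assumed per-channel effect of extract_from_bn (weight must be non-zero):
    the BN scale moves into g, the BN weight becomes a frozen 1, and the bias is
    rescaled so that g * bn(x) is unchanged."""
    return 1.0, bias / weight, g * weight  # (new_weight, new_bias, new_g)


def gated_bn_channel(x, mean, var, weight, bias, g, mask=1.0):
    """Gated BN output for one channel: bn(x) * g * mask."""
    return bn_channel(x, mean, var, weight, bias) * g * mask


x, mean, var, w, b = 1.1, 0.2, 1.5, 0.7, -0.3
nw, nb, ng = fold_into_gate(w, b)
print(bn_channel(x, mean, var, w, b))              # original BN output
print(gated_bn_channel(x, mean, var, nw, nb, ng))  # same value after folding
print(gated_bn_channel(x, mean, var, nw, nb, ng, mask=0.0))  # channel switched off
```

After folding, g alone carries the channel's scale, so its magnitude can serve as an importance signal, and a zero mask (or gate) removes the channel's contribution entirely.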

 

import uuid

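def bottleneck_set_group(net): # put the BNs of the 3 stages (layer1-3) into 3 groups
    layers = [
        net.module.layer1,
        net.module.layer2,
        net.module.layer3
    ]
    for m in layers:
        masks = []
        if m is net.module.layer1: # also add the bn layer right before layer1 to the first group
            masks.append(net.module.bn1) # use the net argument, not the global pack.net
        for mm in m.modules():
            if mm.__class__.__name__ == 'BasicBlock':
                if len(mm.shortcut._modules) > 0: # a non-empty shortcut is the projection between stages whose channel counts differ
                    masks.append(mm.shortcut._modules['1']) # it also contains a bn layer
                masks.append(mm.bn2) # bn1 is not added: its output stays inside the block and never meets the residual addition, so it can be pruned on its own

        group_id = uuid.uuid1()
        for mk in masks: # every entry in masks is a GatedBatchNorm2d
            mk.set_groupid(group_id) # set_groupid is defined in GatedBatchNorm2d; it overwrites this GBN's id with the shared group id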
        print(masks)

bottleneck_set_group(pack.net)

Output:

[GatedBatchNorm2d(
  16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)]
[GatedBatchNorm2d(
  32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)]
[GatedBatchNorm2d(
  64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
  64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)]

As shown, the GBNs are split into 3 groups, and every GBN within a group has the same ID.

The network structure at this point:

for name, child in pack.net.named_children():
    print(name)
    print(child)

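Giving every BN in a group the same id marks them as members of one group. BNs in the same group then get identical bn_mask values, i.e. they keep exactly the same channels (and therefore the same channel count), so the tensors added together in the residual connections still line up.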
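The effect of a shared group id can be illustrated with the scoring logic in prune() shown earlier: the scores of all GBNs with the same group_id are summed, one global threshold is applied, and the resulting mask is shared by every member. The sketch below is a simplified, hypothetical version: group_masks just prunes the num_to_prune lowest summed scores across the whole network and ignores the minimal_filter constraint and any extra bookkeeping that get_threshold does.

```python
def group_masks(gbns, num_to_prune):
    """gbns: list of (group_id, per-channel scores) pairs, one per gated BN.
    Returns {group_id: mask} with mask[c] == 0 for pruned channels, 1 for kept ones."""
    # 1) Sum the scores of all GBNs sharing a group id (like status[gid]['score'] += ... in prune()).
    summed = {}
    for gid, scores in gbns:
        if gid in summed:
            # members of one group must have the same channel count
            assert len(summed[gid]) == len(scores)
            summed[gid] = [a + b for a, b in zip(summed[gid], scores)]
        else:
            summed[gid] = list(scores)
    # 2) Rank every (group, channel) pair globally by its summed score.
    ranked = sorted(
        ((s, gid, c) for gid, ss in summed.items() for c, s in enumerate(ss)),
        key=lambda t: t[0],
    )
    # 3) Drop the lowest-scoring channels; all members of a group share the resulting mask.
    masks = {gid: [1] * len(ss) for gid, ss in summed.items()}
    for _, gid, c in ranked[:num_to_prune]:
        masks[gid][c] = 0
    return masks


gbns = [("stage1", [0.9, 0.1, 0.5]), ("stage1", [0.8, 0.2, 0.1]), ("other", [0.4, 0.05])]
print(group_masks(gbns, 2))
```

Channel 1 of "stage1" is removed from both of its GBNs at once, which is exactly what keeps the shortcut and the residual branch at matching widths.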

 

Next, an identical copy of the model is created to compute the model's FLOPs and parameter count:

def clone_model(net):
    model = get_model()
    gbns = GatedBatchNorm2d.transform(model.module)
    model.load_state_dict(net.state_dict())
    return model, gbns

cloned, _ = clone_model(pack.net)
#BASE_FLOPS, BASE_PARAM = analyse_model(cloned.module, torch.randn(1, 3, 32, 32).cuda())
BASE_FLOPS, BASE_PARAM = analyse_model(cloned.module, torch.randn(1, 3, 32, 32))
print('%.3f MFLOPS' % (BASE_FLOPS / 1e6))
print('%.3f M' % (BASE_PARAM / 1e6))

for name, child in cloned.named_children():
    print(name)
    print(child)
    
del cloned

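The clone is identical to the original except for the IDs. It is deleted with del cloned right afterwards, because it is only needed for this measurement.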

Output:

127.932 MFLOPS
0.856 M
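For intuition about the numbers analyse_model reports, a common way to count one Conv2d layer is params = C_out * C_in * k * k (plus bias) and multiply-accumulates = weights * H_out * W_out. The helper below is a sketch of that convention only; I have not checked whether analyse_model counts MACs or 2 * MACs, or how it treats BN and the final linear layer.

```python
def conv2d_cost(c_in, c_out, k, h_out, w_out, bias=False):
    """Parameters and multiply-accumulates (MACs) of one square-kernel Conv2d with groups=1."""
    weights = c_out * c_in * k * k
    params = weights + (c_out if bias else 0)
    macs = weights * h_out * w_out  # every weight is used once per output pixel
    return params, macs


# First conv of a CIFAR ResNet: 3 -> 16 channels, 3x3 kernel, 32x32 output, no bias
print(conv2d_cost(3, 16, 3, 32, 32))
# Pruning a 16 -> 16 conv to 8 -> 8 (input and output halved) cuts its cost to a quarter
print(conv2d_cost(16, 16, 3, 32, 32)[1] / conv2d_cost(8, 8, 3, 32, 32)[1])
```

Because a pruned channel disappears both as an output of one layer and as an input of the next, FLOPs usually fall faster than the fraction of filters removed.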

The clone_model(net) helper above really exists to serve the following eval_prune(pack) function:

def eval_prune(pack):
    cloned, _ = clone_model(pack.net)
    _ = Conv2dObserver.transform(cloned.module)
    cloned.module.linear = FinalLinearObserver(cloned.module.linear)
    cloned_pack = dotdict(pack.copy())
    cloned_pack.net = cloned
    Meltable.observe(cloned_pack, 0.001)
    Meltable.melt_all(cloned_pack.net)
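    # take the input device from the model instead of hard-coding .cuda(), so this also runs with cuda: false
    device = next(cloned_pack.net.parameters()).device
    flops, params = analyse_model(cloned_pack.net.module, torch.randn(1, 3, 32, 32, device=device))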
    del cloned
    del cloned_pack
    
    return flops, params

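It simply computes the FLOPs and parameter count of the current model, which are compared with those of the original model to see how much it has been compressed so far.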

 

Next, test the current model to see how it performs:

pack.trainer.test(pack)

Output:

{'test_loss': 0.31250936849207817, 'acc@1': 92.92919303797468}

 

Then comes the tick-tock procedure:

pack.tick_trainset = pack.train_loader
prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)

LOGS = []
flops_save_points = set([40, 38, 35, 32, 30]) # save the model once its FLOPs drop to 40%, 38%, ... of the original model

iter_idx = 0
# start with one tock: train for tock_epoch epochs
prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)
while True:
    left_filter = prune_agent.total_filters - prune_agent.pruned_filters
    num_to_prune = int(left_filter * cfg.gbn.p)
    info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min) # run a tick, then prune
    flops, params = eval_prune(pack) # measure the current model
    info.update({
        'flops': '[%.2f%%] %.3f MFLOPS' % (flops/BASE_FLOPS * 100, flops / 1e6),
        'param': '[%.2f%%] %.3f M' % (params/BASE_PARAM * 100, params / 1e6)
    })
    LOGS.append(info)
    print('Iter: %d,\t FLOPS: %s,\t Param: %s,\t Left: %d,\t Pruned Ratio: %.2f %%,\t Train Loss: %.4f,\t Test Acc: %.2f' % 
          (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100, info['train_loss'], info['after_prune_test_acc']))
    
    iter_idx += 1
    if iter_idx % cfg.gbn.T == 0: # i.e. after every cfg.gbn.T (= 10) ticks, run a tock of cfg.gbn.tock_epoch epochs
        print('Tocking:')
        prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

    flops_ratio = flops/BASE_FLOPS * 100 # current FLOPs as a percentage of the original model
    for point in list(flops_save_points): # iterate over a copy so points can be removed; saves at 40%, 38%, ... of the original FLOPs
        if flops_ratio <= point:
            torch.save(pack.net.module.state_dict(), './logs/resnet56_cifar10_ticktock/%s.ckp' % str(point))
            flops_save_points.remove(point)

    if len(flops_save_points) == 0: # stop pruning once FLOPs reach 30%
        break
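The control flow of the tick-tock loop can be replayed without a network. The toy below keeps the same structure (prune int(left * p) filters per tick, tock every T iterations, save once the FLOPs ratio crosses each save point, stop when all points are used up) but replaces prune_agent and eval_prune with a made-up FLOPs model, ratio = (left / total) ** 2 * 100. That FLOPs model and the name simulate_ticktock are purely assumptions for illustration.

```python
def simulate_ticktock(total_filters, p, T, save_points):
    """Toy replay of the tick-tock while-loop; returns (events, final FLOPs ratio in %)."""
    pending = set(save_points)
    pruned = 0
    iter_idx = 0
    events = []  # (iteration, "tock" or "save", save point or None)
    while True:
        left = total_filters - pruned
        # guard for the toy only: plain int(left * p) can round down to 0,
        # and then the toy would never terminate
        num_to_prune = max(1, int(left * p))
        pruned += num_to_prune  # stands in for prune_agent.prune(..., tick=True)
        # assumed FLOPs model: every conv loses input and output channels at the same rate
        flops_ratio = ((total_filters - pruned) / total_filters) ** 2 * 100
        iter_idx += 1
        if iter_idx % T == 0:
            events.append((iter_idx, "tock", None))
        for point in sorted(pending, reverse=True):  # sorted() copies, so removing is safe
            if flops_ratio <= point:
                events.append((iter_idx, "save", point))
                pending.remove(point)
        if not pending:
            return events, flops_ratio


events, ratio = simulate_ticktock(2000, 0.01, 10, [40, 38, 35, 32, 30])
for e in events:
    print(e)
print("final FLOPs ratio: %.2f%%" % ratio)
```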

 

The fine-tuning step is in resnet-56/finetune.ipynb.

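The core step is to first prune the model according to the bn_mask values. Not only the BN layers are cut, but also the Conv2d layers and the final linear layer (via Conv2dObserver and FinalLinearObserver):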

GBNs = GatedBatchNorm2d.transform(pack.net)
for gbn in GBNs:
    gbn.extract_from_bn()

_ = Conv2dObserver.transform(pack.net.module)
pack.net.module.linear = FinalLinearObserver(pack.net.module.linear)
Meltable.observe(pack, 0.001)
Meltable.melt_all(pack.net) # prune: physically remove the masked channels
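Meltable.melt_all is what turns the masks into a genuinely smaller network. Conceptually, a masked channel is dropped from the layer that produces it and from the input side of the layer that consumes it. The sketch below shows that idea for two 1x1 layers written as plain matrices, with BN and ReLU omitted. It illustrates the principle only and is not the repo's implementation; melt_pair, matvec, masked_forward and plain_forward are hypothetical names.

```python
def matvec(w, x):
    """Matrix-vector product; a 1x1 conv on a single pixel reduces to this."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def masked_forward(w_prev, bias_prev, w_next, mask, x):
    """Before melting: masked channels are still computed but multiplied by 0."""
    mid = [(h + b) * m for h, b, m in zip(matvec(w_prev, x), bias_prev, mask)]
    return matvec(w_next, mid)


def plain_forward(w_prev, bias_prev, w_next, x):
    """After melting: no mask needed, the pruned channels simply do not exist."""
    mid = [h + b for h, b in zip(matvec(w_prev, x), bias_prev)]
    return matvec(w_next, mid)


def melt_pair(w_prev, bias_prev, w_next, mask):
    """Remove masked channels between two consecutive layers.
    w_prev: [c_mid][c_in], bias_prev: [c_mid], w_next: [c_out][c_mid], mask: [c_mid] of 0/1."""
    keep = [c for c, m in enumerate(mask) if m]
    new_prev = [w_prev[c] for c in keep]                    # drop output rows
    new_bias = [bias_prev[c] for c in keep]
    new_next = [[row[c] for c in keep] for row in w_next]   # drop matching input columns
    return new_prev, new_bias, new_next


w_prev, bias_prev = [[1, 2], [3, 4], [5, 6]], [1, 1, 1]
w_next, mask, x = [[1, 1, 1], [2, 0, -1]], [1, 0, 1], [2, 1]
small = melt_pair(w_prev, bias_prev, w_next, mask)
print(masked_forward(w_prev, bias_prev, w_next, mask, x))  # masked, full-size layers
print(plain_forward(*small, x))                            # melted, smaller layers
```

Both calls give the same result, which is why fine-tuning can start from the melted network without any drop in accuracy caused by the surgery itself.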

Then fine-tune:

_ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)

 

 

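One unusual idiom here is worth recording:

Making sure the project's main (root) directory is used as the path: see my separate post on ensuring the main path in Python.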
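The idea is to make the repository root both the working directory and the first entry of sys.path, so that modules such as logger can be imported and relative paths such as ./logs resolve no matter where a script or notebook is started. Below is a hedged sketch, assuming the notebooks sit two levels below the root (run/resnet-56/). It is not necessarily the exact idiom used in the repo, and use_project_root is a hypothetical name.

```python
import os
import sys


def use_project_root(levels_up=2, start=None):
    """Hypothetical helper: treat the directory `levels_up` above `start` as the
    project root, make it the working directory and put it first on sys.path.
    `start` defaults to this file's folder, or to the cwd inside a notebook."""
    if start is None:
        start = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
    root = os.path.abspath(os.path.join(start, *([os.pardir] * levels_up)))
    os.chdir(root)               # relative paths like ./logs now resolve from the root
    if root not in sys.path:
        sys.path.insert(0, root)  # `from logger import logger` now finds the root module
    return root


# e.g. at the top of run/resnet-56/resnet56_prune.ipynb:
# use_project_root(2)
```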

posted @ 2020-01-10 18:29  慢行厚积