Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks - Model Compression - 2 - Code Study
https://github.com/youzhonghui/gate-decorator-pruning
1.utils.py
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
It inherits from dict, so it is still a plain dict underneath; it only adds attribute-style (dot) access.
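A minimal usage sketch of the dotdict above (my own example, not from the repo), showing that attribute access and key access are interchangeable; note that nested plain dicts are not converted automatically, which is what make_as_dotdict() in config.py takes care of:

d = dotdict({'batch_size': 128})
d.lr = 0.1                    # same as d['lr'] = 0.1
print(d.batch_size, d['lr'])  # 128 0.1
print(d.missing)              # None (dict.get returns None for absent keys)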
2.loader/__init__.py
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

from config import cfg
from loader.cifar10 import get_cifar10
from loader.cifar100 import get_cifar100
from loader.imagenet import get_imagenet

def get_loader():
    pair = {  # pick the dataset according to the name set in the config
        'cifar10': get_cifar10,
        'cifar100': get_cifar100,
        'imagenet': get_imagenet
    }
    return pair[cfg.data.type]()
This selects which dataset to use; the corresponding config entry is:
from config import parse_from_dict
parse_from_dict({
    ...
    "data": {
        "type": "cifar10",        # the dataset name to use
        "shuffle": True,
        "batch_size": 128,
        "test_batch_size": 128,
        "num_workers": 4
    ...
3.models/__init__.py
import torch
from config import cfg

def get_vgg16_for_cifar():
    from models.cifar.vgg import VGG
    return VGG('VGG16', cfg.model.num_class)

def get_resnet50_for_imagenet():
    from models.imagenet.resnet50 import Resnet50
    return Resnet50(cfg.model.num_class)

def get_resnet56():
    from models.cifar.resnet56 import resnet56
    return resnet56(cfg.model.num_class)

def get_model():
    pair = {  # pick the model according to the name set in the config
        'cifar.vgg16': get_vgg16_for_cifar,
        'resnet50': get_resnet50_for_imagenet,
        'cifar.resnet56': get_resnet56
    }

    model = pair[cfg.model.name]()

    if cfg.base.checkpoint_path != '':  # load a trained checkpoint if one is given
        print('restore checkpoint: ' + cfg.base.checkpoint_path)
        model.load_state_dict(torch.load(cfg.base.checkpoint_path, map_location='cpu' if not cfg.base.cuda else 'cuda'))

    if cfg.base.cuda:  # single GPU
        model = model.cuda()

    if cfg.base.multi_gpus:  # multiple GPUs
        model = torch.nn.DataParallel(model)
    return model
This selects which model to use for classification, chooses CPU or GPU, and loads a pretrained checkpoint if one is provided.
The corresponding config entry is:
from config import parse_from_dict
parse_from_dict({
    ...
    "model": {
        "name": "cifar.resnet56",
        "num_class": 10,
        "pretrained": False
    },
4.loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

from config import cfg

def get_criterion():
    pair = {  # the available loss functions
        'softmax': nn.CrossEntropyLoss()
    }
    assert (cfg.loss.criterion in pair)
    criterion = pair[cfg.loss.criterion]
    return criterion
The cross-entropy loss is used.
The corresponding config entry is:
from config import parse_from_dict
parse_from_dict({
    ...
    "loss": {
        "criterion": "softmax"
5. config.py
import argparse
import json
from utils import dotdict

def make_as_dotdict(obj):  # recursively convert a dict into a dotdict
    if type(obj) is dict:
        obj = dotdict(obj)
        for key in obj:
            if type(obj[key]) is dict:
                obj[key] = make_as_dotdict(obj[key])
    return obj

def parse():
    print('Parsing config file...')
    parser = argparse.ArgumentParser(description="config")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/base.json",
        help="Configuration file to use"
    )
    cli_args = parser.parse_args()

    with open(cli_args.config) as fp:
        config = make_as_dotdict(json.loads(fp.read()))
    print(json.dumps(config, indent=4, sort_keys=True))
    return config

class Singleton(object):
    _instance = None
    def __new__(cls, *args, **kw):
        if not cls._instance:
            cls._instance = super(Singleton, cls).__new__(cls, *args, **kw)
        return cls._instance

class Config(Singleton):
    def __init__(self):
        self._cfg = dotdict({})
        try:
            self._cfg = parse()
        except:
            pass

    def __getattr__(self, name):
        if name == '_cfg':
            super().__setattr__(name)
        else:
            return self._cfg.__getattr__(name)

    def __setattr__(self, name, val):
        if name == '_cfg':
            super().__setattr__(name, val)
        else:
            self._cfg.__setattr__(name, val)

    def __delattr__(self, name):  # called when an attribute is removed with del
        return self._cfg.__delitem__(name)

    def copy(self, new_config):
        self._cfg = make_as_dotdict(new_config)

cfg = Config()

def parse_from_dict(d):  # replace the whole config with the given dict (as a dotdict)
    global cfg
    assert type(d) == dict
    cfg.copy(d)
This sets up the configuration. The point of the dotdict format is that it allows attribute-style access, e.g. cfg.base.cuda instead of cfg['base']['cuda'], which keeps the rest of the code concise.
parse_from_dict() is called later by the prune and finetune scripts to set the configuration, e.g.:
from config import parse_from_dict
parse_from_dict({
    "base": {
        "task_name": "resnet56_cifar10_ticktock",
        "cuda": True,
        "seed": 0,
        "checkpoint_path": "",
        "epoch": 0,
        "multi_gpus": True,
        "fp16": False
    },
    "model": {
        "name": "cifar.resnet56",
        "num_class": 10,
        "pretrained": False
    },
    "train": {
        "trainer": "normal",
        "max_epoch": 160,
        "optim": "sgd",
        "steplr": [
            [80, 0.1],     # epoch <= 80: lr = 0.1
            [120, 0.01],   # 80 < epoch <= 120: lr = 0.01
            [160, 0.001]   # 120 < epoch <= 160: lr = 0.001
        ],
        "weight_decay": 5e-4,
        "momentum": 0.9,
        "nesterov": False
    },
    "data": {
        "type": "cifar10",
        "shuffle": True,
        "batch_size": 128,
        "test_batch_size": 128,
        "num_workers": 4
    },
    "loss": {
        "criterion": "softmax"
    },
    "gbn": {
        "sparse_lambda": 1e-3,
        "flops_eta": 0,
        "lr_min": 1e-3,
        "lr_max": 1e-2,
        "tock_epoch": 10,
        "T": 10,
        "p": 0.002
    }
})
from config import cfg
6. trainer/__init__.py
from trainer.normal import NormalTrainer
from config import cfg
def get_trainer():
    pair = {
        'normal': NormalTrainer
    }
    assert (cfg.train.trainer in pair)
    return pair[cfg.train.trainer]()
This selects the trainer class, i.e. where the train() and test() functions used later are defined.
Aside:
#coding:utf-8
import torch

if __name__ == '__main__':
    a = torch.FloatTensor([[3, 14, 15, 13], [5, 4, 15, 7]]).t()
    b = torch.FloatTensor([[3, 3, 3, 3], [5, 5, 5, 5]]).t()
    correct = a.eq(b)
    print(correct)
    print(correct[:1])
    print(correct[:1].view(-1))
    print(correct[:1].view(-1).float())
    correct_k = correct[:1].view(-1).float().sum(0, keepdim=True)
    print(correct_k)
Output:
tensor([[ True,  True],
        [False, False],
        [False, False],
        [False, False]])
tensor([[True, True]])
tensor([True, True])
tensor([1., 1.])
tensor([2.])
trainer/normal.py:
from time import time
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

from config import cfg

FINISH_SIGNAL = 'finish'

def accuracy(output, target, topk=(1,)):  # classification accuracy
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)  # e.g. top-1 or top-5
        batch_size = target.size(0)

        # indices of the maxk largest predictions, shape (batch_size, maxk)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # transpose to (maxk, batch_size)
        # target: (batch_size,) -> (1, batch_size), expanded to (maxk, batch_size) and
        # compared with pred; each column holds at most one True, so correct
        # (shape (maxk, batch_size)) contains at most batch_size True values
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))  # accuracy as a percentage
        return res

class NormalTrainer():
    def __init__(self):
        self.use_cuda = cfg.base.cuda

    def test(self, pack, topk=(1,)):  # evaluation
        pack.net.eval()
        loss_acc, correct, total = 0.0, 0.0, 0.0
        hub = [[] for i in range(len(topk))]

        for data, target in pack.test_loader:
            if self.use_cuda:
                data, target = data.cuda(), target.cuda()
            with torch.no_grad():  # no backward pass
                output = pack.net(data)
                loss_acc += pack.criterion(output, target).data.item()  # accumulate the loss

            acc = accuracy(output, target, topk)  # accuracy
            for acc_idx, score in enumerate(acc):
                hub[acc_idx].append(score[0].item())

        loss_acc /= len(pack.test_loader)  # average test loss
        info = {
            'test_loss': loss_acc
        }
        for acc_idx, k in enumerate(topk):
            info['acc@%d' % k] = np.mean(hub[acc_idx])  # top-1, top-5, ... accuracy
        return info

    def train(self, pack, loss_hook=None, iter_hook=None, update=True, mute=False, acc_step=1):
        # train for one epoch; mute controls whether the progress bar and info are shown
        pack.net.train()
        loss_acc, correct_acc, total = 0.0, 0.0, 0.0
        begin = time()

        pack.optimizer.zero_grad()
        with tqdm(total=len(pack.train_loader), disable=mute) as pbar:
            total_iter = len(pack.train_loader)  # total number of iterations (batches)
            for cur_iter, (data, label) in enumerate(pack.train_loader):
                if iter_hook is not None:
                    signal = iter_hook(cur_iter, total_iter)
                    if signal == FINISH_SIGNAL:  # early-stop signal
                        break

                if self.use_cuda:
                    data, label = data.cuda(), label.cuda()
                data = Variable(data, requires_grad=False)
                label = Variable(label)

                logits = pack.net(data)
                loss = pack.criterion(logits, label)
                if loss_hook is not None:
                    additional = loss_hook(data, label, logits)
                    loss += additional

                loss = loss / acc_step
                loss.backward()

                if (cur_iter + 1) % acc_step == 0:
                    if update:
                        pack.optimizer.step()
                    pack.optimizer.zero_grad()

                loss_acc += loss.item()
                pbar.update(1)

        info = {
            'train_loss': loss_acc / len(pack.train_loader),
            'epoch_time': time() - begin
        }
        return info
One call to train() runs through the whole dataset exactly once (one epoch): it ends when enumerate(pack.train_loader) is exhausted.
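Continuing the aside above, a small sketch of how accuracy() reports top-1 and top-2 results. The logits and labels are made up for illustration; the snippet assumes accuracy() from trainer/normal.py above is in scope:

import torch

# hypothetical logits for a batch of 3 samples over 4 classes
output = torch.tensor([[0.1, 0.6, 0.2, 0.1],   # predicts class 1
                       [0.5, 0.1, 0.3, 0.1],   # predicts class 0, runner-up class 2
                       [0.2, 0.3, 0.4, 0.1]])  # predicts class 2, runner-up class 1
target = torch.tensor([1, 3, 1])

top1, top2 = accuracy(output, target, topk=(1, 2))
print(top1)  # tensor([33.3333]) -- only the first sample is right at top-1
print(top2)  # tensor([66.6667]) -- samples 1 and 3 have their label inside the top-2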
7.main.py
""" * Copyright (C) 2019 Zhonghui You * If you are using this code in your research, please cite the paper: * Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks, in NeurIPS 2019. """ import torch import torch.nn as nn import torch.optim as optim import numpy as np import random import math from loader import get_loader from models import get_model from trainer import get_trainer from loss import get_criterion from utils import dotdict from config import cfg from logger import logger def _sgdr(epoch): lr_min, lr_max = cfg.train.sgdr.lr_min, cfg.train.sgdr.lr_max restart_period = cfg.train.sgdr.restart_period _epoch = epoch - cfg.train.sgdr.warm_up while _epoch/restart_period > 1.: _epoch = _epoch - restart_period restart_period = restart_period * 2. radians = math.pi*(_epoch/restart_period) return lr_min + (lr_max - lr_min) * 0.5*(1.0 + math.cos(radians)) def _step_lr(epoch): v = 0.0 for max_e, lr_v in cfg.train.steplr: #max_e是到这个step的学习率都是lr_v v = lr_v if epoch <= max_e: break return v def get_lr_func(): if cfg.train.steplr is not None: return _step_lr elif cfg.train.sgdr is not None: return _sgdr else: assert False def adjust_learning_rate(epoch, pack): #设置使用的优化器,并设置学习率调节函数,以及更新学习率 if pack.optimizer is None: if cfg.train.optim == 'sgd' or cfg.train.optim is None: pack.optimizer = optim.SGD( pack.net.parameters(), lr=1, momentum=cfg.train.momentum, weight_decay=cfg.train.weight_decay, nesterov=cfg.train.nesterov ) else: print('WRONG OPTIM SETTING!') assert False pack.lr_scheduler = optim.lr_scheduler.LambdaLR(pack.optimizer, get_lr_func()) pack.lr_scheduler.step(epoch) return pack.lr_scheduler.get_lr() def recover_pack(): train_loader, test_loader = get_loader() pack = dotdict({ 'net': get_model(), 'train_loader': train_loader, 'test_loader': test_loader, 'trainer': get_trainer(), 'criterion': get_criterion(), 'optimizer': None, 'lr_scheduler': None }) adjust_learning_rate(cfg.base.epoch, pack) return pack def set_seeds(): #用来保证代码中随机数每次都一样 torch.manual_seed(cfg.base.seed) if cfg.base.cuda: torch.cuda.manual_seed_all(cfg.base.seed) torch.backends.cudnn.deterministic = True if cfg.base.fp16: torch.backends.cudnn.enabled = True # torch.backends.cudnn.benchmark = True np.random.seed(cfg.base.seed) random.seed(cfg.base.seed) def main(): set_seeds() #设置中设置的"seed": 0,就是用在这的 pack = recover_pack() #设置各个参数和使用的模型、数据等 for epoch in range(cfg.base.epoch + 1, cfg.train.max_epoch + 1): lr = adjust_learning_rate(epoch, pack) # 更新lr info = pack.trainer.train(pack) #训练模型,得到损失和准确率等信息 info.update(pack.trainer.test(pack)) #加入测试时的损失和准确率等信息 info.update({'LR': lr}) #记录此时的lr print(epoch, info) logger.save_record(epoch, info) #写入日志 if epoch % cfg.base.model_saving_interval == 0: logger.save_network(epoch, pack.net) # 保存网络 if __name__ == '__main__': main()
8.logger.py
import torch
from config import cfg

import os
import json
import numpy as np

class MetricsRecorder():
    def __init__(self):
        self.rec = {}

    def add(self, pairs):
        for key, val in pairs.items():
            if key not in self.rec:
                self.rec[key] = []
            self.rec[key].append(val)

    def mean(self):
        r = {}
        for key, val in self.rec.items():
            r[key] = np.mean(val)
        return r

class Logger():
    def __init__(self):
        self.base_path = './logs/' + cfg.base.task_name
        self.logfile = self.base_path + '/log.json'
        self.cfgfile = self.base_path + '/cfg.json'

        if not os.path.isdir(self.base_path):
            os.makedirs(self.base_path, exist_ok=True)

        with open(self.logfile, 'w') as fp:
            json.dump({}, fp)   # the log starts out empty
        with open(self.cfgfile, 'w') as fp:
            json.dump(cfg, fp)  # the cfg file starts as a copy of the config

    def save_record(self, epoch, record):
        # save the train/test loss and accuracy of this run, indexed by the current epoch
        with open(self.logfile) as fp:
            log = json.load(fp)
        log[str(epoch)] = record
        with open(self.logfile, 'w') as fp:
            json.dump(log, fp)

    def save_network(self, epoch, network):
        saving_path = self.base_path + '/ckp.%d.torch' % epoch
        print('saving model ...')
        if type(network) is torch.nn.DataParallel:
            torch.save(network.module.state_dict(), saving_path)
        else:
            torch.save(network.state_dict(), saving_path)
        cfg.base.epoch = epoch
        cfg.base.checkpoint_path = saving_path
        with open(self.cfgfile, 'w') as fp:  # save the updated config
            json.dump(cfg, fp)

logger = None
if logger is None:
    logger = Logger()
The log files are created under a folder named after the task, i.e. the setting:
from config import parse_from_dict
parse_from_dict({
    "base": {
        "task_name": "resnet18",  # the task name
A folder with that name is created under ./logs to hold log.json and cfg.json; save_record() writes the intermediate results there, and save_network() stores the checkpoints in the same folder:
user@jiayuan:/opt/.../gate-decorator-pruning/logs/resnet18$ ls
cfg.json log.json
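Reading the log back is just a JSON load; the keys are epoch numbers and the values are the merged train/test records from main.py. The numbers in the comment below are made up for illustration:

import json

with open('./logs/resnet18/log.json') as fp:
    log = json.load(fp)

# hypothetical contents after two epochs:
# {'1': {'train_loss': 1.82, 'epoch_time': 31.2, 'test_loss': 1.54, 'acc@1': 44.3, 'LR': 0.1},
#  '2': {'train_loss': 1.35, 'epoch_time': 30.9, 'test_loss': 1.21, 'acc@1': 56.7, 'LR': 0.1}}
print(log['2']['acc@1'])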
Next come prune and finetune, which are the important parts.
9.prune/utils.py
#coding:utf-8
import os

if __name__ == '__main__':
    print(os.devnull)  # /dev/null
The code:
import torch
import torch.nn as nn

import os, contextlib
from thop import profile

def analyse_model(net, inputs):
    # silence
    with open(os.devnull, 'w') as devnull:            # os.devnull is /dev/null on Linux
        with contextlib.redirect_stdout(devnull):     # redirect stdout to /dev/null
            flops, params = profile(net, (inputs, ))  # thop estimates the model's FLOPs and params
    return flops, params

def finetune(pack, lr_min, lr_max, T, mute=False):
    # T is the number of fine-tuning epochs, e.g. 40
    logs = []
    epoch = 0

    def iter_hook(curr_iter, total_iter):  # passed into train() as its iter_hook argument
        # total_iter is the number of batches per epoch, so fine-tuning runs T * total_iter batches in total
        total = T * total_iter
        half = total / 2
        # curr_iter is the batch index inside the current epoch; itered is the global batch count so far
        itered = epoch * total_iter + curr_iter

        if itered < half:
            # first half: ramp the learning rate linearly from lr_min up to lr_max
            _iter = epoch * total_iter + curr_iter
            _lr = (1 - _iter / half) * lr_min + (_iter / half) * lr_max
        else:
            # second half: ramp back down from lr_max to lr_min
            # (the only difference is that lr_min and lr_max swap positions)
            _iter = (epoch - T / 2) * total_iter + curr_iter
            _lr = (1 - _iter / half) * lr_max + (_iter / half) * lr_min

        for g in pack.optimizer.param_groups:
            g['lr'] = max(_lr, 0)
            g['momentum'] = 0.0

    for i in range(T):  # train for T epochs
        info = pack.trainer.train(pack, iter_hook=iter_hook)
        info.update(pack.trainer.test(pack))
        info.update({'LR': pack.optimizer.param_groups[0]['lr']})
        epoch += 1
        if not mute:  # whether to print the loss and accuracy info
            print(info)
        logs.append(info)

    return logs
This fine-tuning operation corresponds to the fine-tune step described in the paper.
The difference from the Tock phase of Tick-Tock is that Tock still uses GBN layers and trains for only a few epochs (typically 10), whereas finetune runs on the small model obtained after the whole pruning procedure, with every GBN converted back to a plain BN, and trains for considerably more epochs.
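A small sketch of the triangular learning-rate schedule that iter_hook implements, evaluated on made-up numbers (T and total_iter are chosen only for illustration):

# Hypothetical: T = 4 epochs, 5 batches per epoch, lr_min = 1e-3, lr_max = 1e-2.
T, total_iter = 4, 5
lr_min, lr_max = 1e-3, 1e-2
half = T * total_iter / 2

for epoch in range(T):
    for curr_iter in range(total_iter):
        itered = epoch * total_iter + curr_iter
        if itered < half:   # ramp up: lr_min -> lr_max
            lr = (1 - itered / half) * lr_min + (itered / half) * lr_max
        else:               # ramp down: lr_max -> lr_min
            _iter = (epoch - T / 2) * total_iter + curr_iter
            lr = (1 - _iter / half) * lr_max + (_iter / half) * lr_min
        print(epoch, curr_iter, round(lr, 4))
# The printed lr rises linearly toward lr_max at the midpoint, then falls back toward lr_min.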
10.prune/universal.py
Background:
1) The uuid library:
A UUID (Universally Unique Identifier) is guaranteed to be unique in both space and time. Uniqueness comes from combining the MAC address, timestamp, namespace, and (pseudo-)random numbers, and a UUID has a fixed size of 128 bits. Because of this, a new UUID can be generated without any registration process. UUIDs can be used to label an object for a short time or to reliably identify a persistent object across a network.
Why use a UUID?
Many scenarios need an id that carries no particular meaning and only identifies an object. A database table's id column is a common example; another is front-end UI libraries that dynamically create elements, each of which needs a unique id.
#coding:utf-8
import uuid

if __name__ == '__main__':
    print(uuid.uuid1())  # 7b24099a-27ae-11ea-b076-00e04c6841ff
In this codebase the uuid is mainly used for networks with shortcut branches such as ResNet: GBN layers that must be pruned together are put into the same group and share the same group ID, whereas in a plain network like VGG every GBN layer keeps its own ID.
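A small illustration of the grouping idea, using the GatedBatchNorm2d class from prune/universal.py shown further below (the layer pairing here is hypothetical; the actual group assignment happens in the per-network pruning scripts):

import uuid
import torch.nn as nn

# two GBN layers whose outputs meet at a residual addition must keep the same channels
gbn_a = GatedBatchNorm2d(nn.BatchNorm2d(16))
gbn_b = GatedBatchNorm2d(nn.BatchNorm2d(16))
print(gbn_a.group_id == gbn_b.group_id)  # False: each layer gets its own uuid1()

shared_id = uuid.uuid1()
gbn_a.set_groupid(shared_id)             # e.g. a block's bn2 and its shortcut's bn
gbn_b.set_groupid(shared_id)             # are placed in one group
print(gbn_a.group_id == gbn_b.group_id)  # True: the pruner now treats them as one group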
2)nn.Parameter
#coding:utf-8
import torch.nn as nn
import torch

if __name__ == '__main__':
    g = nn.Parameter(torch.ones(1, 3, 1, 1), requires_grad=True)
    print(g)
Output:
Parameter containing:
tensor([[[[1.]],
         [[1.]],
         [[1.]]]], requires_grad=True)
Wrapping a tensor in nn.Parameter turns a plain (non-trainable) Tensor into a trainable parameter and binds it to the module, so it shows up in net.parameters() and gets updated by the optimizer. In other words, the value becomes part of the model: a parameter that training can change.
A model's weights and biases are nn.Parameters, so they are trained and optimized; a Variable (plain tensor) only serves as the model's input.
buffers() returns an iterator over a module's buffers. A buffer is persistent state carried between forward passes rather than a parameter to optimize, e.g. the running mean and variance used by BatchNorm2d, which are updated as the BatchNorm2d layer processes data.
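A minimal sketch contrasting nn.Parameter and register_buffer (a toy module of my own, not from the repo): both end up in state_dict(), but only the parameter is returned by parameters() and is therefore seen by the optimizer.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # trainable: visible to the optimizer
        self.g = nn.Parameter(torch.ones(1, 3, 1, 1))
        # persistent state: saved/loaded with the model but not optimized
        self.register_buffer('bn_mask', torch.ones(1, 3, 1, 1))

m = Toy()
print([name for name, _ in m.named_parameters()])  # ['g']
print([name for name, _ in m.named_buffers()])     # ['bn_mask']
print(list(m.state_dict().keys()))                 # ['g', 'bn_mask']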
3) This is why GatedBatchNorm2d sets up its parameters this way in its initialization:
def extract_from_bn(self):
    # freeze bn weight
    with torch.no_grad():
        # clamp self.bn.bias / self.bn.weight into [-10, 10]
        self.bn.bias.set_(torch.clamp(self.bn.bias / self.bn.weight, -10, 10))
        self.g.set_(self.g * self.bn.weight.view(1, -1, 1, 1))
        self.bn.weight.set_(torch.ones_like(self.bn.weight))  # torch.ones_like(x) == torch.ones(x.size())
        self.bn.weight.requires_grad = False
In the paper's notation Φ is g, β is bn.bias, and γ is self.bn.weight. After this transform the BN scale is frozen at 1, the bias becomes β/γ, and the original γ is absorbed into the gate g, so (up to the clamp) the layer still computes γ·x̂ + β.
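A quick numerical check of that identity on made-up numbers (ignoring the clamp, which only matters when |β/γ| > 10): the original BN output γ·x̂ + β equals the GBN output (1·x̂ + β/γ)·Φ with Φ = γ.

import torch

# hypothetical per-channel BN parameters and normalized activations
gamma = torch.tensor([0.5, 2.0, -1.5])   # bn.weight
beta  = torch.tensor([0.1, -0.3, 0.7])   # bn.bias
x_hat = torch.randn(3)                   # normalized input at one spatial position

before = gamma * x_hat + beta                 # original BN output
phi    = gamma.clone()                        # gate g after extract_from_bn
after  = (1.0 * x_hat + beta / gamma) * phi   # BN with weight=1, bias=beta/gamma, times the gate
print(torch.allclose(before, after))          # True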
After pruning, once the filters to drop have been determined, the code that converts a GBN back into a plain BN layer is:
def melt(self):
    with torch.no_grad():
        # flatten the mask: one entry per channel, 0 means that channel was pruned
        mask = self.bn_mask.view(-1)
        replacer = nn.BatchNorm2d(int(self.bn_mask.sum())).to(self.bn.weight.device)
        replacer.running_var.set_(self.bn.running_var[mask != 0])    # BatchNorm2d running variance
        replacer.running_mean.set_(self.bn.running_mean[mask != 0])  # BatchNorm2d running mean
        replacer.weight.set_((self.bn.weight * self.g.view(-1))[mask != 0])
        replacer.bias.set_((self.bn.bias * self.g.view(-1))[mask != 0])
    return replacer
The full code:
import torch
import torch.nn as nn
import numpy as np
import uuid

OBSERVE_TIMES = 5
FINISH_SIGNAL = 'finish'

class Meltable(nn.Module):
    def __init__(self):
        super(Meltable, self).__init__()

    @classmethod
    def melt_all(cls, net):
        def _melt(modules):
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    _melt(modules[k]._modules)
                if isinstance(modules[k], Meltable):
                    modules[k] = modules[k].melt()

        _melt(net._modules)

    @classmethod
    def observe(cls, pack, lr):
        tmp = pack.train_loader
        if pack.tick_trainset is not None:
            pack.train_loader = pack.tick_trainset  # use the tick training subset if provided

        for m in pack.net.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.abs_().add_(1e-3)

        def replace_relu(modules):  # replace ReLU with LeakyReLU
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    replace_relu(modules[k]._modules)
                if isinstance(modules[k], nn.ReLU):
                    modules[k] = nn.LeakyReLU(inplace=True)
        replace_relu(pack.net._modules)

        count = 0
        def _freeze_bn(curr_iter, total_iter):
            for m in pack.net.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
            nonlocal count
            count += 1
            if count == OBSERVE_TIMES:
                return FINISH_SIGNAL

        info = pack.trainer.train(pack, iter_hook=_freeze_bn, update=False, mute=True)  # update=False: no optimizer step

        def recover_relu(modules):  # put the ReLUs back
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    recover_relu(modules[k]._modules)
                if isinstance(modules[k], nn.LeakyReLU):
                    modules[k] = nn.ReLU(inplace=True)
        recover_relu(pack.net._modules)

        for m in pack.net.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.abs_().add_(-1e-3)  # undo the earlier shift

        pack.train_loader = tmp

class GatedBatchNorm2d(Meltable):
    def __init__(self, bn, minimal_ratio=0.1):
        super(GatedBatchNorm2d, self).__init__()
        assert isinstance(bn, nn.BatchNorm2d)
        self.bn = bn
        self.group_id = uuid.uuid1()

        self.channel_size = bn.weight.shape[0]
        self.minimal_filter = max(1, int(self.channel_size * minimal_ratio))  # lower bound on kept channels
        self.device = bn.weight.device
        self._hook = None

        # a trainable gate, one value per channel
        self.g = nn.Parameter(torch.ones(1, self.channel_size, 1, 1).to(self.device), requires_grad=True)

        # three buffers (nn.Module.register_buffer): persistent state that is not optimized
        self.register_buffer('area', torch.zeros(1).to(self.device))
        self.register_buffer('score', torch.zeros(1, self.channel_size, 1, 1).to(self.device))
        # bn_mask records whether each channel has been pruned (0 = pruned); initialized to all ones
        self.register_buffer('bn_mask', torch.ones(1, self.channel_size, 1, 1).to(self.device))

        self.extract_from_bn()  # rewrite the bn weight, bn bias and g as described above

    def set_groupid(self, new_id):
        self.group_id = new_id

    def extra_repr(self):
        # after pruning the channel count goes from channel_size to bn_mask.sum()
        return '%d -> %d | ID: %s' % (self.channel_size, int(self.bn_mask.sum()), self.group_id)

    def extract_from_bn(self):
        # freeze bn weight
        with torch.no_grad():
            # clamp bn.bias / bn.weight into [-10, 10]
            self.bn.bias.set_(torch.clamp(self.bn.bias / self.bn.weight, -10, 10))
            self.g.set_(self.g * self.bn.weight.view(1, -1, 1, 1))
            self.bn.weight.set_(torch.ones_like(self.bn.weight))
            self.bn.weight.requires_grad = False

    def reset_score(self):
        self.score.zero_()

    def cal_score(self, grad):
        # used for hook; Eq. 6 in the paper: accumulate |dL/dg * g| as the importance score.
        # The smaller a channel's score (down to 0), the safer it is to remove that channel.
        self.score += (grad * self.g).abs()

    def start_collecting_scores(self):
        if self._hook is not None:
            self._hook.remove()
        # after the backward pass computes the gradient of g, cal_score updates
        # self.score, which is later used for ranking
        self._hook = self.g.register_hook(self.cal_score)

    def stop_collecting_scores(self):
        if self._hook is not None:
            self._hook.remove()  # remove the hook created by register_hook
            self._hook = None

    def get_score(self, eta=0.0):
        # use self.bn_mask.sum() to calculate the number of input channel. eta should had been normed
        # since bn_mask starts as torch.ones(1, channel_size, 1, 1), its sum() equals channel_size
        flops_reg = eta * int(self.area[0]) * self.bn_mask.sum()
        return ((self.score - flops_reg) * self.bn_mask).view(-1)

    def forward(self, x):
        x = self.bn(x) * self.g  # g is the gate used to rank channel importance
        self.area[0] = x.shape[-1] * x.shape[-2]  # height * width = feature-map area

        if self.bn_mask is not None:
            return x * self.bn_mask
        return x

    def melt(self):
        with torch.no_grad():
            mask = self.bn_mask.view(-1)  # bn_mask.sum() is the number of kept channels
            replacer = nn.BatchNorm2d(int(self.bn_mask.sum())).to(self.bn.weight.device)
            replacer.running_var.set_(self.bn.running_var[mask != 0])    # BatchNorm2d running variance
            replacer.running_mean.set_(self.bn.running_mean[mask != 0])  # BatchNorm2d running mean
            replacer.weight.set_((self.bn.weight * self.g.view(-1))[mask != 0])
            replacer.bias.set_((self.bn.bias * self.g.view(-1))[mask != 0])
        return replacer

    @classmethod
    def transform(cls, net, minimal_ratio=0.1):
        r = []
        def _inject(modules):
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    _inject(modules[k]._modules)
                if isinstance(modules[k], nn.BatchNorm2d):
                    # replace every nn.BatchNorm2d with a GatedBatchNorm2d;
                    # at least max(1, int(channel_size * minimal_ratio)) filters are kept
                    modules[k] = GatedBatchNorm2d(modules[k], minimal_ratio)
                    r.append(modules[k])
        _inject(net._modules)
        return r
4) Pruning the convolution layers:
class Conv2dObserver(Meltable):
    def __init__(self, conv):
        super(Conv2dObserver, self).__init__()
        assert isinstance(conv, nn.Conv2d)
        self.conv = conv
        self.in_mask = torch.zeros(conv.in_channels).to('cpu')
        self.out_mask = torch.zeros(conv.out_channels).to('cpu')
        self.f_hook = conv.register_forward_hook(self._forward_hook)  # run on every forward pass of this conv

    def extra_repr(self):
        return '(%d, %d) -> (%d, %d)' % (
            self.conv.in_channels,
            self.conv.out_channels,
            int((self.in_mask != 0).sum()),
            int((self.out_mask != 0).sum()))

    def _forward_hook(self, m, _in, _out):
        x = _in[0]
        # in_mask records whether each input channel was pruned; a channel whose
        # accumulated absolute sum stays 0 was pruned upstream
        self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)

    def _backward_hook(self, grad):  # run after the gradient of the output is computed
        # out_mask: an output channel whose accumulated gradient sum stays 0 was pruned downstream
        self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
        new_grad = torch.ones_like(grad)
        return new_grad

    def forward(self, x):
        output = self.conv(x)
        noise = torch.zeros_like(output).normal_()
        output = output + noise  # ?
        if self.training:
            output.register_hook(self._backward_hook)
        return output

    def melt(self):
        if self.conv.groups == 1:
            groups = 1
        elif self.conv.groups == self.conv.out_channels:
            groups = int((self.out_mask != 0).sum())
        else:
            assert False

        replacer = nn.Conv2d(
            in_channels=int((self.in_mask != 0).sum()),
            out_channels=int((self.out_mask != 0).sum()),
            kernel_size=self.conv.kernel_size,
            stride=self.conv.stride,
            padding=self.conv.padding,
            dilation=self.conv.dilation,
            groups=groups,
            bias=(self.conv.bias is not None)
        ).to(self.conv.weight.device)

        with torch.no_grad():
            if self.conv.groups == 1:
                replacer.weight.set_(self.conv.weight[self.out_mask != 0][:, self.in_mask != 0])
            else:
                replacer.weight.set_(self.conv.weight[self.out_mask != 0])

            if self.conv.bias is not None:
                replacer.bias.set_(self.conv.bias[self.out_mask != 0])
        return replacer

    @classmethod
    def transform(cls, net):
        r = []
        def _inject(modules):
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    _inject(modules[k]._modules)
                if isinstance(modules[k], nn.Conv2d):
                    modules[k] = Conv2dObserver(modules[k])
                    r.append(modules[k])
        _inject(net._modules)
        return r
5) How the final fully connected classification layer is handled:
class FinalLinearObserver(Meltable):
    ''' assert was in the last layer. only input was masked '''
    def __init__(self, linear):
        super(FinalLinearObserver, self).__init__()
        assert isinstance(linear, nn.Linear)
        self.linear = linear
        self.in_mask = torch.zeros(linear.weight.shape[1]).to('cpu')
        self.f_hook = linear.register_forward_hook(self._forward_hook)  # run on every forward pass of this linear layer

    def extra_repr(self):
        return '(%d, %d) -> (%d, %d)' % (
            int(self.linear.weight.shape[1]),
            int(self.linear.weight.shape[0]),
            int((self.in_mask != 0).sum()),
            int(self.linear.weight.shape[0]))

    def _forward_hook(self, m, _in, _out):
        x = _in[0]
        # sum each input column over the batch; a column that stays 0 corresponds to a pruned channel
        self.in_mask += x.data.abs().cpu().sum(0, keepdim=True).view(-1)

    def forward(self, x):
        return self.linear(x)

    def melt(self):
        # rebuild the linear layer with the pruned number of input features
        with torch.no_grad():
            replacer = nn.Linear(int((self.in_mask != 0).sum()), self.linear.weight.shape[0]).to(self.linear.weight.device)
            replacer.weight.set_(self.linear.weight[:, self.in_mask != 0])
            replacer.bias.set_(self.linear.bias)
        return replacer
These two classes wrap the convolution layers and the final linear layer as Conv2dObserver and FinalLinearObserver respectively.
Conv2dObserver, for example, keeps the two buffers in_mask and out_mask. During training they accumulate per-channel sums in the forward and backward passes; a channel whose accumulated sum ends up 0 has been pruned, i.e.:
def _forward_hook(self, m, _in, _out):
    x = _in[0]
    # in_mask records whether each input channel was pruned; a channel whose sum stays 0 was pruned
    self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)

def _backward_hook(self, grad):  # run after the gradient is computed in the backward pass
    # an output channel whose accumulated gradient sum stays 0 was pruned
    self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
    new_grad = torch.ones_like(grad)
    return new_grad
They are mainly used in the melting step: as each GBN is turned into a pruned BN layer, the neighbouring convolution and fully connected layers use the in_mask and out_mask derived from the adjacent GBN layers to prune the corresponding filters, so that the channel counts of the whole network stay consistent.
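A minimal sketch of the mask idea (made-up tensors, not the repo's code): channels that an upstream GBN has zeroed out contribute nothing to the accumulated sum, so their entry in in_mask stays 0 and they are dropped when the observer melts.

import torch

# hypothetical activation batch: 2 samples, 4 channels, 3x3 feature maps
x = torch.randn(2, 4, 3, 3)
x[:, 2] = 0.0  # pretend channel 2 was zeroed by the preceding GBN's bn_mask

in_mask = torch.zeros(4)
# same reduction as Conv2dObserver._forward_hook: sum |x| over H, W and the batch
in_mask += x.abs().sum(2, keepdim=True).sum(3, keepdim=True).sum(0, keepdim=True).view(-1)
print(in_mask)       # entry 2 is 0.0, the others are positive
print(in_mask != 0)  # tensor([ True,  True, False,  True]) -> keep channels 0, 1, 3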
6) The sparsity loss on the gates:
def get_gate_sparse_loss(masks, sparse_lambda):
    def _loss_hook(data, label, logits):
        loss = 0.0
        for gbn in masks:
            if isinstance(gbn, GatedBatchNorm2d):
                loss += gbn.g.abs().sum()
        return sparse_lambda * loss
    return _loss_hook
This computes the sparsity part of the Tock loss; later in the code it is passed to train() as loss_hook, so it really is an extra term added on top of the normal loss.
It corresponds to the λ-weighted L1 penalty on the gate values Φ in the paper's Tock-phase objective.
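A sketch of how this hook plugs into NormalTrainer.train(); the names pack and gbn_modules follow the earlier snippets, and the exact call sites in the repo's prune scripts may differ:

# gbn_modules would be the list returned by GatedBatchNorm2d.transform(pack.net)
sparse_loss_hook = get_gate_sparse_loss(gbn_modules, sparse_lambda=1e-3)

# inside train(), the hook's return value is simply added to the criterion loss:
#   loss = pack.criterion(logits, label)
#   if loss_hook is not None:
#       loss += loss_hook(data, label, logits)   # + lambda * sum_i |g_i|
info = pack.trainer.train(pack, loss_hook=sparse_loss_hook)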
Inspect the resnet56 network structure used in resnet56_prune:
for name, module in pack.net.named_modules():
    print(name)
    print(module)
Output:
DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      ... (BasicBlocks (1) through (8) of layer1 are identical to (0)) ...
    )
    (layer2): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      ...
7) The remaining Tick-Tock part is easier to explain with concrete examples; see Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks - Model Compression - 3 - Code Study (VGG16, ResNet).