pytorch 实现剪枝的思路是 生成一个掩码,然后同时保存 原参数、mask、新参数,如下图
pytorch 剪枝分为 局部剪枝、全局剪枝、自定义剪枝;
局部剪枝 是对 模型内 的部分模块 的 部分参数 进行剪枝,全局剪枝是对 整个模型进行剪枝;
本文旨在记录 pytorch 剪枝模块的用法,首先让我们构建一个模型
import torch from torch import nn import torch.nn.utils.prune as prune import torch.nn.functional as F device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() # 1 input image channel, 6 output channels, 3x3 square conv kernel self.conv1 = nn.Conv2d(1, 6, 3) self.conv2 = nn.Conv2d(6, 16, 3) self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) x = F.max_pool2d(F.relu(self.conv2(x)), 2) x = x.view(-1, int(x.nelement() / x.shape[0])) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x model = LeNet().to(device=device)
下面对 这个模型进行剪枝
局部剪枝
以修剪 第一层卷积 模块 为例
module = model.conv1 print(list(module.named_parameters())) print(list(module.buffers())) # 修剪是从 模块 中 删除 参数(如 weight),并用 weight_orig 保存该参数 # random_unstructured 是一种裁剪技术,随机非结构化裁剪 prune.random_unstructured(module, name="weight", amount=0.3) # weight bias print(list(module.named_parameters())) # 通过修剪技术会创建一个mask命名为 weight_mask 的模块缓冲区 print(list(module.named_buffers())) # 新的参数保存为模块 的weight属性 print(module.weight) # print(module.bias) print(module._forward_pre_hooks) # OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x000002695EBCEC18>)])
named_parameters() 内 存储的对象 除非手动删除,否则在剪枝过程中对其无影响
迭代剪枝
迭代剪枝 是 对 同一模块 进行 多种剪枝,执行逻辑是 顺序执行各剪枝操作
在之前 随机非结构化剪枝 的基础上进行 L1 L2 非结构化剪枝
## 增加一个修剪,看看变化 # l1范数修剪bias中3个最小条目 prune.l1_unstructured(module, name="bias", amount=3) print(module.bias) print(module._forward_pre_hooks) # OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x000002695EBCEC18>), # (1, <torch.nn.utils.prune.L1Unstructured object at 0x000002695DE5CEB8>)]) print(list(module.named_parameters())) print(list(module.named_buffers())) ### 迭代修剪 # 一个模块中的同一参数可以被多次修剪,多次修剪会顺序执行 # 如在之前的基础上,对 weight 参数继续修剪 # l2 结构化裁剪,n=2代表l2,dim=0代表在weight的第0轴进行结构化裁剪 prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0) # 查看 weight 参数的 剪枝 操作 for hook in module._forward_pre_hooks.values(): if hook._tensor_name == "weight": # select out the correct hook break print(list(hook)) # [<torch.nn.utils.prune.RandomUnstructured object at 0x0000020AE2A6EC18>, # <torch.nn.utils.prune.LnStructured object at 0x0000020AA872DE80>] print(module.state_dict().keys()) # odict_keys(['weight_orig', 'bias_orig', 'weight_mask', 'bias_mask'])
修剪模型中的多个参数
### 修剪模型中的多个参数 new_model = LeNet() for name, module in new_model.named_modules(): # prune 20% of connections in all 2D-conv layers if isinstance(module, torch.nn.Conv2d): prune.l1_unstructured(module, name='weight', amount=0.2) # prune 40% of connections in all linear layers elif isinstance(module, torch.nn.Linear): prune.l1_unstructured(module, name='weight', amount=0.4) print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist
全局剪枝
以上研究通常被称为“局部”修剪方法,即通过比较每个条目的统计信息(权重,激活度,梯度等)来逐一修剪模型中的张量的做法。
但是,一种常见且可能更强大的技术是通过删除整个模型中最低的 20%的连接,
而不是删除每一层中最低的 20%的连接来修剪模型。
这很可能导致每个层的修剪百分比不同。
让我们看看如何使用torch.nn.utils.prune中的global_unstructured进行操作
model = LeNet() parameters_to_prune = ( (model.conv1, 'weight'), (model.conv2, 'weight'), (model.fc1, 'weight'), (model.fc2, 'weight'), (model.fc3, 'weight'), ) prune.global_unstructured( parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2, ) # 检查每个修剪参数的稀疏性,该稀疏性不等于每层中的 20%。 但是,全局稀疏度将(大约)为 20%
自定义剪枝
见 参考资料3
训练中剪枝实例
见参考资料1
参考资料:
https://blog.csdn.net/qq_40268672/article/details/108631518 pytorch剪枝实战 训练时剪枝,类似 dropout
https://blog.csdn.net/ssunshining/article/details/125121066 PyTorch--模型剪枝案例
https://www.w3cschool.cn/pytorch/pytorch-rnmi3bti.html PyTorch 修剪教程
https://www.bilibili.com/video/BV147411W7am?spm_id_from=333.337.search-card.all.click&vd_source=f0fc90583fffcc40abb645ed9d20da32 神经网络剪枝 Neural Network Pruning 自定义的剪枝
https://github.com/mepeichun/Efficient-Neural-Network-Bilibili/tree/master/2-Pruning 上面视频的 代码 已下载