03常用pytorch剪枝工具

常用剪枝工具

pytorch官方案例

import torch.nn.utils.prune as prune

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
print(torch.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = LeNet().to(device=device)
module = model.conv1
prune.random_structurd(module, name="weight", amount=0.3, dim=1)

#对同一层进行连续不同的剪枝
prune.l1_unstructured(module, name="weight", amount=3)
prune.l1_unstructured(module, name="bias", amount=3)
prune.ln_structured(module, name="bias", amount=0.5, n=3, dim=0)

序列化剪枝后的模型

在PyTorch中,named_buffers()是一个模型的方法,它返回一个迭代器,这个迭代器包含了模型中所有持久化的缓冲区。在每次迭代中,它返回一个包含缓冲区名(name)和缓冲区的张量(tensor)的元组。

在神经网络中,有些数据虽然不是模型参数(也就是不会在反向传播中被更新),但是这些数据在前向传播过程中是需要的,这些数据就被称为缓冲区(buffer)。缓冲区通常用于存储不参与梯度计算,但需要在训练过程中持久化的数据。例如,批归一化(Batch Normalization)层中的运行平均值和运行方差就是存储在缓冲区中的。

对于剪枝操作来说,剪枝的掩码通常会被保存为一个缓冲区。这个掩码的作用是在前向传播过程中把被剪枝的权重(也就是被设为0的权重)从计算中排除出去。

所以,named_buffers()函数就是用来获取模型中所有缓冲区的名称和对应的数据。这在进行剪枝操作时,可以用来检查剪枝的掩码是否已经被正确地添加到模型中。

#state_dict()是一个PyTorch模型的方法,它返回一个字典,其中包含了模型的所有参数,包括权重和偏置。字典的键是参数的名称,值是参数的值。这个字典可以用于保存和加载模型的参数。
#keys()是Python字典的一个方法,它返回字典的所有键的列表。
#所以,model.state_dict().keys()返回的是一个包含模型中所有参数名称的列表。weight和bias
print(model.state_dict().keys())

new_model = LeNet()
#这行代码开始遍历模型中的所有模块(或层)。named_modules()函数返回一个迭代器,每次迭代返回一个包含模块名(name)和模块实例(module)的元组。
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

global pruning

model = LeNet()
#第一个元素是model,第二个元素是这个model里哪一些参数要被剪掉
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)
#进行全局无结构剪枝
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

"Sparsity"(稀疏性)是一个数学概念,用于描述一个矩阵中零元素的比例。在深度学习中,稀疏性通常用来描述模型权重矩阵中零值的比例。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

自定义pruning functions

下面是每隔一个就进行一次非结构化剪枝

自定义剪枝pytorch官方教程: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#:~:text=Global sparsity%3A 20.00%25-,Extending torch.nn.utils.prune with custom pruning,-functions

pytorch源码参考: https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/prune.py#:~:text=%40abstractmethod,method recipe.

#该类是prune.BasePruningMethod的子类
class ImplEveryOtherPruningMethod(prune.BasePruningMethod):
    #定义剪枝类型
    PRUNING_TYPE = 'unstructured'
    #重写了基类中的抽象方法compute_mask。该方法接收两个参数,一个是待剪枝的张量t,另一个是默认的掩码default_mask。
    def compute_mask(self, t, default_mask):
        #创建一个default_mask的副本,这是为了避免改变原始的default_mask。
        mask = default_mask.clone()
        #这个操作首先将掩码的形状改为一维mask.view(-1),然后选择索引为偶数的所有元素[::2],将它们设置为0。这样就达到了每隔一个元素剪枝的效果。
        mask.view(-1)[::2] = 0
        return mask
def Ieveryother_unstructured_prune(module, name):
    #生成一个想要的mask,并且apply到module的元素上
    ImplEveryOtherPruningMethod.apply(module, name) 
    return module
model = LeNet()
Ieveryother_unstructured_prune(model.fc3, name='bias')

print(model.fc3.bias_mask)
posted @ 2023-07-01 16:40  DemonSlayer  阅读(101)  评论(0编辑  收藏  举报