深度学习之模型压缩(剪枝、量化)

随着深度学习的发展,模型变得越来越复杂,随之而来的模型参数也越来越多,对于需要训练的模型硬件要求也越来越高。模型压缩技术就是为了解决模型使用成本的问题。通过提高推理速度,降低模型参数量和运算量。现在主流的模型压缩方法包含两大类:剪枝和量化。模型的剪枝是为了减少参数量和运算量,而量化是为了压缩数据的占用量。

(1)模型的剪枝:

剪枝的思路在工程上非常常见,在学习决策树的时候就有通过剪枝的方法来防止过拟合,同样深度学习模型剪枝就是利用这种思想,来删除收益过低的一些计算成本。

基于深度神经网络的大型预训练模型往往拥有着庞大的参数量, 然后达到SOTA的效果。但是我们参考生物的神经网络, 发现却是依靠大量稀疏的连接来完成复杂的意识活动。仿照生物的稀疏神经网络, 通过将大型网络中的稠密连接变成稀疏的连接, 在训练的过程中,逐步将权重较小的参数置为0,然后把那些权重值为0的去掉,也可以达到SOTA的效果, 就是模型的剪枝方法。

Pytorch的模型剪枝方法

  • 第一种: 对特定网络模块的剪枝(Pruning Model).
  • 第二种: 多参数模块的剪枝(Pruning multiple parameters).
  • 第三种: 全局剪枝(GLobal pruning).
  • 第四种: 用户自定义剪枝(Custom pruning).
# 第一种: 对特定网络模块的剪枝(Pruning Model).

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 是经历卷积操作后的图片尺寸
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = LeNet().to(device=device)

module = model.conv1
print(list(module.named_parameters()))

print(list(module.named_buffers()))

# 第一个参数: module, 代表要进行剪枝的特定模块, 之前我们已经制定了module=model.conv1,
#             说明这里要对第一个卷积层执行剪枝.
# 第二个参数: name, 指定要对选中的模块中的哪些参数执行剪枝.
#             这里设定为name="weight", 意味着对连接网络中的weight剪枝, 而不对bias剪枝.
# 第三个参数: amount, 指定要对模型中多大比例的参数执行剪枝.
#             amount是一个介于0.0-1.0的float数值, 或者一个正整数指定剪裁掉多少条连接边.

prune.random_unstructured(module, name="weight", amount=0.3)

print(list(module.named_parameters()))
print(list(module.named_buffers()))

# 模型经历剪枝操作后, 原始的权重矩阵weight参数不见了,
# 变成了weight_orig. 并且刚刚打印为空列表的module.named_buffers(),
# 此时拥有了一个weight_mask参数.

print(module.weight)
# 经过剪枝操作后的模型, 原始的参数存放在了weight_orig中,
# 对应的剪枝矩阵存放在weight_mask中, 而将weight_mask视作掩码张量,
# 再和weight_orig相乘的结果就存放在了weight中.

# 我们可以对模型的任意子结构进行剪枝操作,
# 除了在weight上面剪枝, 还可以对bias进行剪枝.

# 第一个参数: module, 代表剪枝的对象, 此处代表LeNet中的conv1
# 第二个参数: name, 代表剪枝对象中的具体参数, 此处代表偏置量
# 第三个参数: amount, 代表剪枝的数量, 可以设置为0.0-1.0之间表示比例, 也可以用正整数表示剪枝的参数绝对数量
prune.l1_unstructured(module, name="bias", amount=3)

# 再次打印模型参数
print(list(module.named_parameters()))
print('*'*50)
print(list(module.named_buffers()))
print('*'*50)
print(module.bias)
print('*'*50)
print(module._forward_pre_hooks)

# 序列化一个剪枝模型(Serializing a pruned model):
# 对于一个模型来说, 不管是它原始的参数, 拥有的属性值, 还是剪枝的mask buffers参数
# 全部都存储在模型的状态字典中, 即state_dict()中.
# 将模型初始的状态字典打印出来
print(model.state_dict().keys())
print('*'*50)

# 对模型进行剪枝操作, 分别在weight和bias上剪枝
module = model.conv1
prune.random_unstructured(module, name="weight", amount=0.3)
prune.l1_unstructured(module, name="bias", amount=3)

# 再将剪枝后的模型的状态字典打印出来
print(model.state_dict().keys())

# 对模型执行剪枝remove操作.
# 通过module中的参数weight_orig和weight_mask进行剪枝, 本质上属于置零遮掩, 让权重连接失效.
# 具体怎么计算取决于_forward_pre_hooks函数.
# 这个remove是无法undo的, 也就是说一旦执行就是对模型参数的永久改变.

# 打印剪枝后的模型参数
print(list(module.named_parameters()))
print('*'*50)

# 打印剪枝后的模型mask buffers参数
print(list(module.named_buffers()))
print('*'*50)

# 打印剪枝后的模型weight属性值
print(module.weight)
print('*'*50)

# 打印模型的_forward_pre_hooks
print(module._forward_pre_hooks)
print('*'*50)

# 执行剪枝永久化操作remove
prune.remove(module, 'weight')
print('*'*50)

# remove后再次打印模型参数
print(list(module.named_parameters()))
print('*'*50)

# remove后再次打印模型mask buffers参数
print(list(module.named_buffers()))
print('*'*50)

# remove后再次打印模型的_forward_pre_hooks
print(module._forward_pre_hooks)

# 对模型的weight执行remove操作后, 模型参数集合中只剩下bias_orig了,
# weight_orig消失, 变成了weight, 说明针对weight的剪枝已经永久化生效.
# 对于named_buffers张量打印可以看出, 只剩下bias_mask了,
# 因为针对weight做掩码的weight_mask已经生效完毕, 不再需要保留了.
# 同理, 在_forward_pre_hooks中也只剩下针对bias做剪枝的函数了.

 

# 第二种: 多参数模块的剪枝(Pruning multiple parameters).
model = LeNet().to(device=device)

# 打印初始模型的所有状态字典
print(model.state_dict().keys())
print('*'*50)

# 打印初始模型的mask buffers张量字典名称
print(dict(model.named_buffers()).keys())
print('*'*50)

# 对于模型进行分模块参数的剪枝
for name, module in model.named_modules():
    # 对模型中所有的卷积层执行l1_unstructured剪枝操作, 选取20%的参数剪枝
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.2)
    # 对模型中所有全连接层执行ln_structured剪枝操作, 选取40%的参数剪枝
    elif isinstance(module, torch.nn.Linear):
        prune.ln_structured(module, name="weight", amount=0.4, n=2, dim=0)

# 打印多参数模块剪枝后的mask buffers张量字典名称
print(dict(model.named_buffers()).keys())
print('*'*50)

# 打印多参数模块剪枝后模型的所有状态字典名称
print(model.state_dict().keys())

# 对比初始化模型的状态字典和剪枝后的状态字典,
# 可以看到所有的weight参数都没有了,
# 变成了weight_orig和weight_mask的组合.
# 初始化的模型named_buffers是空列表,
# 剪枝后拥有了所有参与剪枝的参数层的weight_mask张量.

 

# 第三种: 全局剪枝(GLobal pruning).

# 第一种, 第二种剪枝策略本质上属于局部剪枝(local pruning)
# 更普遍也更通用的剪枝策略是采用全局剪枝(global pruning),
# 比如在整体网络的视角下剪枝掉20%的权重参数,
# 而不是在每一层上都剪枝掉20%的权重参数.
# 采用全局剪枝后, 不同的层被剪掉的百分比不同.

model = LeNet().to(device=device)

# 首先打印初始化模型的状态字典
print(model.state_dict().keys())
print('*'*50)

# 构建参数集合, 决定哪些层, 哪些参数集合参与剪枝
parameters_to_prune = (
            (model.conv1, 'weight'),
            (model.conv2, 'weight'),
            (model.fc1, 'weight'),
            (model.fc2, 'weight'),
            (model.fc3, 'weight'))

# 调用prune中的全局剪枝函数global_unstructured执行剪枝操作, 此处针对整体模型中的20%参数量进行剪枝
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

# 最后打印剪枝后的模型的状态字典
print(model.state_dict().keys())

model = LeNet().to(device=device)

parameters_to_prune = (
            (model.conv1, 'weight'),
            (model.conv2, 'weight'),
            (model.fc1, 'weight'),
            (model.fc2, 'weight'),
            (model.fc3, 'weight'))

prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
    100. * float(torch.sum(model.conv1.weight == 0))
    / float(model.conv1.weight.nelement())
    ))

print(
    "Sparsity in conv2.weight: {:.2f}%".format(
    100. * float(torch.sum(model.conv2.weight == 0))
    / float(model.conv2.weight.nelement())
    ))

print(
    "Sparsity in fc1.weight: {:.2f}%".format(
    100. * float(torch.sum(model.fc1.weight == 0))
    / float(model.fc1.weight.nelement())
    ))

print(
    "Sparsity in fc2.weight: {:.2f}%".format(
    100. * float(torch.sum(model.fc2.weight == 0))
    / float(model.fc2.weight.nelement())
    ))

print(
    "Sparsity in fc3.weight: {:.2f}%".format(
    100. * float(torch.sum(model.fc3.weight == 0))
    / float(model.fc3.weight.nelement())
    ))

print(
    "Global sparsity: {:.2f}%".format(
    100. * float(torch.sum(model.conv1.weight == 0)
               + torch.sum(model.conv2.weight == 0)
               + torch.sum(model.fc1.weight == 0)
               + torch.sum(model.fc2.weight == 0)
               + torch.sum(model.fc3.weight == 0))
         / float(model.conv1.weight.nelement()
               + model.conv2.weight.nelement()
               + model.fc1.weight.nelement()
               + model.fc2.weight.nelement()
               + model.fc3.weight.nelement())
    ))

# 当采用全局剪枝策略的时候(假定20%比例参数参与剪枝),
# 仅保证模型总体参数量的20%被剪枝掉,
# 具体到每一层的情况则由模型的具体参数分布情况来定.

 

# 第四种: 用户自定义剪枝(Custom pruning).
# 剪枝模型通过继承class BasePruningMethod()来执行剪枝,
# 内部有若干方法: call, apply_mask, apply, prune, remove等等.
# 一般来说, 用户只需要实现__init__, 和compute_mask两个函数即可完成自定义的剪枝规则设定.
import time
# 自定义剪枝方法的类, 一定要继承prune.BasePruningMethod
class myself_pruning_method(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    # 内部实现compute_mask函数, 完成程序员自己定义的剪枝规则, 本质上就是如何去mask掉权重参数
    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        # 此处定义的规则是每隔一个参数就遮掩掉一个, 最终参与剪枝的参数量的50%被mask掉
        mask.view(-1)[::2] = 0
        return mask

# 自定义剪枝方法的函数, 内部直接调用剪枝类的方法apply
def myself_unstructured_pruning(module, name):
    myself_pruning_method.apply(module, name)
    return module


# 实例化模型类
model = LeNet().to(device=device)

start = time.time()
# 调用自定义剪枝方法的函数, 对model中的第三个全连接层fc3中的偏置bias执行自定义剪枝
myself_unstructured_pruning(model.fc3, name="bias")

# 剪枝成功的最大标志, 就是拥有了bias_mask参数
print(model.fc3.bias_mask)

# 打印一下自定义剪枝的耗时
duration = time.time() - start
print(duration * 1000, 'ms')

# 打印出来的bias_mask张量, 完全是按照预定义的方式每隔一位遮掩掉一位,
#  0和1交替出现, 后续执行remove操作的时候,
# 原始的bias_orig中的权重就会同样的被每隔一位剪枝掉一位.

(2)模型的量化:

量化就是将这些连续的权值进一步稀疏化、离散化。进行离散化之后,相较于原来的连续稠密的值就可以用离散的值来表示了。例如,现在有256个值,是从0到255的整数,那么可以看出这一组数字从统计上来看熵是非常大的,因为分布非常均匀。你是很难对这样的数字表示进行压缩的,要想表示出它们当中的每一个,你必须用8bit的数据来表示。可是,如果这些数字集中在某些数字周围呢?比如256个值里面有56个是8,100个是7,100个9,情况会有什么不同吗?从直观感觉上了看,熵肯定是要小很多的,因为确定性高了很多。那么我们如果用3bit来表示它的中心位置8,再用2bit表示偏移量——1表示+1,0表示无偏移,-1表示-1。那么数据的存储空间又有很大的节省。原来的256个值,每个是8bit,那么一共需要2048个字节才能把数据全都记下来。而用了新方法后,每个值都可以表示为3bit的中心点和2bit的偏移量的大小,那么就变成了5bit来表示一个数字,一共需要1280的字节就够了。

把所有的权值尝试着聚拢到一起,就是尝试找到多个簇,并找的各簇的中心点,在这个图上示意是找到了4个不同的中心点,然后用2bit的信息来表示中心点的编号。然后每个中心点的具体位置具体在列表中标出来(centroids),就是2.00,1.50,0.00和-1.00这几个值。这样记录中心点的矩阵就会小很多,这个过程就叫量化。相较于剪枝而言,量化更容易推广到不同的网络结构中。以下通过在CIFAR-10数据上进行模型的量化,最终结果如下:

config.py

import torch

init_epoch_lr = [(10, 0.01), (20, 0.001), (20, 0.0001)]
sparisity_list = [50, 60, 70, 80, 90]

finetune_epoch_lr = [
    [(3, 0.01),  (3, 0.001),  (3, 0.0001)],
    [(6, 0.01),  (6, 0.001),  (6, 0.0001)],
    [(9, 0.01),  (9, 0.001),  (9, 0.0001)],
    [(12, 0.01), (12, 0.001), (12, 0.0001)],
    [(20, 0.01), (20, 0.001), (20, 0.0001)]
]

checkpoint = 'checkpoint'

batch_size = 128
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model.py

from torch import nn


class VGG_prunable(nn.Module):
    def __init__(self, cfg):
        super(VGG_prunable, self).__init__()
        self.features = self._make_layers(cfg)
        self.classifier = nn.Linear(cfg[-2], 10)

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [
                    nn.Conv2d(in_channels=in_channels,
                              out_channels=x, kernel_size=3, padding=1),
                    nn.BatchNorm2d(x),
                    nn.ReLU(inplace=True)
                ]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out


def VGG_11_prune(cfg=None):
    if cfg is None:
        cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
    return VGG_prunable(cfg)


if __name__ == '__main__':
    print(VGG_11_prune())

##################################################################################
VGG_prunable(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (16): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (17): ReLU(inplace=True)
    (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (20): ReLU(inplace=True)
    (21): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (22): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (23): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (24): ReLU(inplace=True)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (27): ReLU(inplace=True)
    (28): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (29): AvgPool2d(kernel_size=1, stride=1, padding=0)
  )
  (classifier): Linear(in_features=512, out_features=10, bias=True)
)
##################################################################################

base_train.py

from config import device, checkpoint, init_epoch_lr
from data import trainloader, trainset, testloader, testset
from model import VGG_11_prune
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
import os
from tqdm import tqdm


def train_epoch(net, optimizer, crition):
    epoch_loss = 0.0
    epoch_acc = 0.0
    for j, (img, label) in tqdm(enumerate(trainloader)):
        img, label = img.to(device), label.to(device)
        out = net(img)
        optimizer.zero_grad()
        loss = crition(out, label)
        loss.backward()
        optimizer.step()
        pred = torch.argmax(out, dim=1)
        acc = torch.sum(pred == label)
        epoch_loss += loss.item()
        epoch_acc += acc.item()

    epoch_acc /= len(trainset)
    epoch_loss /= len(trainloader)
    print('epoch loss :{:8f} epoch acc :{:8f}'.format(epoch_loss, epoch_acc))
    return epoch_acc, epoch_loss, net


def validation(net, criteron):
    with torch.no_grad():
        test_loss = 0.0
        test_acc = 0.0
        for k, (img, label) in tqdm(enumerate(testloader)):
            img, label = img.to(device), label.to(device)
            out = net(img)
            loss = criteron(out, label)
            pred = torch.argmax(out, dim=1)
            acc = torch.sum(pred == label)
            test_loss += loss.item()
            test_acc += acc.item()
        test_acc /= len(testset)
        test_loss /= len(testloader)
        print('test loss :{:8f} test acc :{:8f}'.format(test_loss, test_acc))
        return test_acc, test_loss


def init_train(net):
    if os.path.exists(os.path.join(checkpoint, 'best_model.pth')):
        save_model = torch.load(os.path.join(checkpoint, 'best_model.pth'))
        net.load_state_dict(save_model['net'])
        if save_model['best_accuracy'] > 0.9:
            print('break init train')
            return
        best_accuracy = save_model['best_accuracy']
        best_loss = save_model['best_loss']
    else:
        best_accuracy = 0.0
        best_loss = 10.0
    writer = SummaryWriter('logs/')
    criteron = torch.nn.CrossEntropyLoss()

    for i, (num_epoch, lr) in enumerate(init_epoch_lr):
        optimizer = optim.SGD(net.parameters(), lr=lr, weight_decay=0.0001, momentum=0.9)
        for epoch in range(num_epoch):
            print('epoch: {}'.format(epoch))
            epoch_acc, epoch_loss, net = train_epoch(net, optimizer, criteron)
            writer.add_scalar('epoch_acc', epoch_acc,
                              sum([e[0] for e in init_epoch_lr[:i]])+epoch)
            writer.add_scalar('epoch_loss', epoch_loss,
                              sum([e[0] for e in init_epoch_lr[:i]]) + epoch)

            test_acc, test_loss = validation(net, criteron)
            if test_loss <= best_loss:
                if test_acc >= best_accuracy:
                    best_accuracy = test_acc

                best_loss = test_loss
                best_model_weights = net.state_dict().copy()
                best_model_params = optimizer.state_dict().copy()
                torch.save(
                    {
                        'net': best_model_weights,
                        'optimizer': best_model_params,
                        'best_accuracy': best_accuracy,
                        'best_loss': best_loss
                    },
                    os.path.join(checkpoint, 'best_model.pth')
                )

            writer.add_scalar('test_acc', test_acc,
                              sum([e[0] for e in init_epoch_lr[:i]]) + epoch)
            writer.add_scalar('test_loss', test_loss,
                              sum([e[0] for e in init_epoch_lr[:i]]) + epoch)

    writer.close()
    return net


if __name__ == '__main__':
    net = VGG_11_prune().to(device)
    init_train(net)

训练完成之后,会在checkpoint文件下生成模型

之后对训练好的模型参数进行量化,代码如下:

quantize.py

import torch
import os
from copy import deepcopy
from collections import OrderedDict
import matplotlib.pyplot as plt

from model import VGG_11_prune
from base_train import validation
from config import checkpoint, device

# 量化权重
def signed_quantize(x, bits, bias=None):
    min_val, max_val = x.min(), x.max()
    n = 2.0 ** (bits -1)
    scale = max(abs(min_val), abs(max_val)) / n
    qx = torch.floor(x / scale)
    if bias is not None:
        qb = torch.floor(bias / scale)
        return qx, qb
    else:
        return qx

# 对模型整体进行量化
def scale_quant_model(model, bits):
    net = deepcopy(model)
    params_quant = OrderedDict()
    params_save = OrderedDict()

    for k, v in model.state_dict().items():
        if 'classifier' not in k and 'num_batches' not in k and 'running' not in k:
            if 'weight' in k:
                weight = v
                bias_name = k.replace('weight', 'bias')
                try:
                    bias = model.state_dict()[bias_name]
                    w, b = signed_quantize(weight, bits, bias)
                    params_quant[k] = w
                    params_quant[bias_name] = b
                    if bits > 8 and bits <= 16:
                        params_save[k] = w.short()
                        params_save[bias_name] = b.short()
                    elif bits >1 and bits <= 8:
                        params_save[k] = w.char()
                        params_save[bias_name] = b.char()
                    elif bits == 1:
                        params_save[k] = w.bool()
                        params_save[bias_name] = b.bool()

                except:
                    w = signed_quantize(w, bits)
                    params_quant[k] = w
                    params_save[k] = w.char()

        else:
            params_quant[k] = v
            params_save[k] = v
    net.load_state_dict(params_quant)
    return net, params_save


if __name__ == '__main__':
    pruned = False
    if pruned:
        channels = [17, 'M', 77, 'M', 165, 182, 'M', 338, 337, 'M', 360, 373, 'M']
        net = VGG_11_prune(channels).to(device)
        net.load_state_dict(
            torch.load(
                os.path.join(checkpoint, 'best_retrain_model.pth'))['compressed_net'])
    else:
        net = VGG_11_prune().to(device)
        net.load_state_dict(
            torch.load(
                os.path.join(checkpoint, 'best_model.pth'), map_location=torch.device('cpu')
            )['net']
        )

    validation(net, torch.nn.CrossEntropyLoss())
    accuracy_list = []
    bit_list = [16, 12, 8, 6, 4, 3, 2, 1]
    for bit in bit_list:
        print('{} bit'.format(bit))
        scale_quantized_model, params = scale_quant_model(net, bit)
        print('validation: ', end='\t')
        accuracy, _ = validation(scale_quantized_model, torch.nn.CrossEntropyLoss())
        accuracy_list.append(accuracy)
        torch.save(params,
                   os.path.join(checkpoint, 'pruned_{}_{}_bits.pth'.format(pruned, bit)))

    plt.plot(bit_list, accuracy_list)
    plt.savefig('img/quantize_pruned:{}.jpg'.format(pruned))
    plt.show()
CIFAR-10未剪枝量化准确率
test loss :0.426187 test acc :0.862100
16 bit
validation: 	[W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
0it [00:00, ?it/s][W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
79it [01:14,  1.06it/s]
test loss :10474.145823 test acc :0.863400
12 bit
validation: 	[W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
0it [00:00, ?it/s][W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
79it [01:15,  1.04it/s]
test loss :659.361133 test acc :0.863300
8 bit
validation: 	[W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
0it [00:00, ?it/s][W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
79it [01:14,  1.06it/s]
test loss :48.506328 test acc :0.851800
6 bit
validation: 	[W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
0it [00:00, ?it/s][W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
79it [01:14,  1.07it/s]
test loss :44.048244 test acc :0.356100
4 bit
validation: 	[W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
0it [00:00, ?it/s][W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
79it [01:15,  1.05it/s]
test loss :5.035617 test acc :0.103500
3 bit
validation: 	[W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
0it [00:00, ?it/s][W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
79it [01:14,  1.06it/s]
test loss :2.572487 test acc :0.099700
2 bit
validation: 	[W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
0it [00:00, ?it/s][W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
79it [01:14,  1.06it/s]
test loss :2.301575 test acc :0.101000
1 bit
validation: 	[W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
0it [00:00, ?it/s][W ParallelNative.cpp:212] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
79it [01:13,  1.07it/s]
test loss :2.303252 test acc :0.100000

由上图可以看到,将参数量化到int8,模型的精度基本没有发生较大的变化,同时模型的大小也缩小为了原来的1/8,基本很好的完成了模型的压缩效果。

posted @ 2024-03-04 15:46  阿风小子  阅读(820)  评论(0编辑  收藏  举报