pytorch动态量化函数

PyTorch 动态量化 API

PyTorch 提供了丰富的动态量化 API,可以帮助开发者轻松地将模型转换为动态量化模型。主要 API 包括:

torch.quantization.quantize_dynamic:将模型转换为动态量化模型。
torch.quantization.QuantStub:观察模型层的输入和输出分布。
torch.quantization.Observer:收集模型层的统计信息。
torch.quantization.DeQuantStub:将定点结果转换回浮点数。

PyTorch torch.quantization.quantize_dynamic 函数详解

torch.quantization.quantize_dynamic 函数是 PyTorch 提供的用于动态量化模型的主要 API。该函数可以将浮点模型转换为动态量化模型,从而显著降低模型大小和提高推理速度。

函数定义

torch.quantization.quantize_dynamic(
    model: torch.nn.Module,
    qconfig: Dict[Type[torch.nn.Module], Dict],
    dtype: torch.qscheme = torch.qint8
) -> torch.nn.Module

参数说明

  • model: 要转换的浮点模型。
  • qconfig: 指定要量化的模块类型和量化配置。
  • dtype: 指定量化的定点数据类型,可以是 torch.qint8torch.float16

函数返回值

quantize_dynamic 函数返回一个新的动态量化模型,该模型与原始模型具有相同的架构和功能。

函数功能

quantize_dynamic 函数主要执行以下操作:

  • 遍历模型中的每个模块。
  • 对于每个模块,检查其类型是否在 qconfig 中定义。
  • 如果模块类型在 qconfig 中定义,则根据 qconfig 中的配置对该模块进行动态量化。
  • 将量化的模块替换到新的模型中。

动态量化配置

qconfig 参数用于指定要量化的模块类型和量化配置。qconfig 是一个字典,其中键是模块类型,值是量化配置字典。量化配置字典可以包含以下键:

  • ``activation`: 指定激活的量化配置。
  • ``weight`: 指定权重的量化配置。
  • ``qscheme: 指定量化方案,可以是 torch.per_tensortorch.per_channel`。
  • ``dynamic`: 指定是否动态量化。

示例

以下是一个简单的示例,演示如何使用 quantize_dynamic 函数将模型转换为动态量化模型:

import torch
import torch.nn as nn
import torch.quantization

# 定义模型
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 1)
)

# 定义量化配置
qconfig = {
    nn.Linear: {
        'activation': {'dtype': torch.qint8},
        'weight': {'dtype': torch.qint8}
    }
}

# 将模型转换为动态量化模型
quantized_model = torch.quantization.quantize_dynamic(
    model,
    qconfig,
    dtype=torch.qint8
)

# 测试模型
input = torch.randn(1, 10)
output = quantized_model(input)
print(output)

在这个示例中,我们定义了一个简单的模型,并使用 qconfig 参数指定了量化配置。qconfig 参数指示 quantize_dynamic 函数对模型中的所有 nn.Linear 模块进行动态量化,并将激活和权重量化为 torch.qint8 格式。

注意事项

在使用 torch.quantization.quantize_dynamic 函数时,需要注意以下几点:

  • 动态量化可能会导致模型精度下降,需要根据具体情况权衡性能和精度。
  • 动态量化目前还不支持所有模型类型和操作。
  • 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。

PyTorch torch.quantization.QuantStub 模块详解

torch.quantization.QuantStub 模块是 PyTorch 提供的用于动态量化模型的观察模块。该模块可以观察模型层的输入和输出分布,并收集统计信息,为动态量化提供必要的数据支持。

模块定义

class QuantStub(nn.Module):
    r"""Quantize stub module, before calibration, this is same as an observer,.

    It will be swapped as nnq.Quantize in convert .

    Parameters:
        qconfig(Dict): quantization configuration for the tensor, if qconfig is not
          provided, we will use the global qconfig
    """

    def __init__(self, qconfig=None):
        super(QuantStub, self).__init__()
        self.qconfig = qconfig

    def forward(self, x):
        return x

模块属性

  • qconfig: 量化配置字典。

模块方法

  • forward(x): 该方法只是简单地返回输入 x,不做任何处理。

模块功能

QuantStub 模块主要用于观察模型层的输入和输出分布,并收集统计信息。在动态量化过程中,QuantStub 模块会被替换为 nnq.Quantize 模块,nnq.Quantize 模块会使用收集的统计信息对输入进行量化。

示例

以下是一个简单的示例,演示如何使用 QuantStub 模块观察模型层的输入和输出分布:

import torch
import torch.nn as nn
import torch.quantization

# 定义模型
model = nn.Sequential(
    QuantStub(qconfig={'dtype': torch.qint8}),
    nn.Linear(10, 20),
    QuantStub(qconfig={'dtype': torch.qint8}),
    nn.ReLU(),
    QuantStub(qconfig={'dtype': torch.qint8}),
    nn.Linear(20, 1)
)

# 测试模型
input = torch.randn(1, 10)
output = model(input)
print(output)

在这个示例中,我们为模型中的每个层都添加了 QuantStub 模块。QuantStub 模块会观察每个层的输入和输出分布,并收集统计信息。

注意事项

在使用 torch.quantization.QuantStub 模块时,需要注意以下几点:

  • QuantStub 模块只用于观察模型层的输入和输出分布,不进行任何量化操作。
  • QuantStub 模块必须与 torch.quantization.DeQuantStub 模块搭配使用,才能完成动态量化。
  • 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。

PyTorch torch.quantization.Observer 模块详解

torch.quantization.Observer 模块是 PyTorch 提供的用于动态量化模型的观察模块。该模块可以观察模型层的输入和输出分布,并收集统计信息,为动态量化提供必要的数据支持。

模块定义

class Observer(nn.Module):
    r"""
    Observer module, which observes tensor quantization ranges for dynamic quantization.

    It attaches to the downstream module to observe the output of the module
    and records the min/max values for quantization.

    Parameters:
        dtype(torch.qscheme): quantization dtype, e.g torch.qint8
        quant_scheme(torch.qscheme): quantization scheme, e.g torch.per_tensor or
                                 torch.per_channel
    """

    def __init__(self, dtype=torch.qint8, quant_scheme=torch.per_tensor):
        super(Observer, self).__init__()
        assert dtype in [
            torch.qint8, torch.quint8, torch.bfloat16
        ], 'Only support torch.qint8, torch.quint8, torch.bfloat16 for now'
        self.dtype = dtype
        self.quant_scheme = quant_scheme
        self.qmin = None
        self.qmax = None
        self._called_once = False

    def forward(self, x):
        r"""Calculates the min/max values for quantization.

        Args:
            x(torch.Tensor): The input tensor to observe.

        Returns:
            torch.Tensor: The input tensor.
        """
        if not self._called_once:
            self._called_once = True
            if self.quant_scheme == torch.per_tensor:
                self.qmin = x.min()
                self.qmax = x.max()
            elif self.quant_scheme == torch.per_channel:
                self.qmin = x.data.min(dim=1)[0]
                self.qmax = x.data.max(dim=1)[0]
            else:
                raise NotImplementedError
        return x

模块属性

  • dtype: 量化数据类型,可以是 torch.qint8torch.quint8torch.bfloat16
  • quant_scheme: 量化方案,可以是 torch.per_tensortorch.per_channel
  • qmin: 最小值。
  • qmax: 最大值。

模块方法

  • forward(x): 该方法计算输入 x 的最小值和最大值,并将其存储在 qminqmax 属性中。

模块功能

Observer 模块主要用于观察模型层的输入和输出分布,并收集统计信息。在动态量化过程中,Observer 模块收集的统计信息将被用于计算量化参数,例如量化尺度和零点。

示例

以下是一个简单的示例,演示如何使用 Observer 模块观察模型层的输入和输出分布:

import torch
import torch.nn as nn
import torch.quantization

# 定义模型
model = nn.Sequential(
    Observer(dtype=torch.qint8, quant_scheme=torch.per_tensor),
    nn.Linear(10, 20),
    Observer(dtype=torch.qint8, quant_scheme=torch.per_tensor),
    nn.ReLU(),
    Observer(dtype=torch.qint8, quant_scheme=torch.per_tensor),
    nn.Linear(20, 1)
)

# 测试模型
input = torch.randn(1, 10)
output = model(input)
print(output)

在这个示例中,我们为模型中的每个层都添加了 Observer 模块。Observer 模块会观察每个层的输入和输出分布,并收集统计信息。

注意事项

在使用 torch.quantization.Observer 模块时,需要注意以下几点:

  • Observer 模块只用于观察模型层的输入和输出分布,不进行任何量化操作。
  • Observer 模块必须与 torch.quantization.DeQuantStub 模块搭配使用,才能完成动态量化。
  • 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。

PyTorch torch.quantization.DeQuantStub 模块详解

torch.quantization.DeQuantStub 模块是 PyTorch 提供的用于动态量化模型的反量化模块。该模块可以将定点张量转换为浮点张量,从而恢复模型的精度。

模块定义

class DeQuantStub(nn.Module):
    r"""Dequantize stub module, before calibration, this is same as identity,.

    It will be swapped as nnq.DeQuantize in convert .

    Parameters:
        qconfig(Dict): quantization configuration for the tensor, if qconfig is not
          provided, we will use the global qconfig
    """

    def __init__(self, qconfig=None):
        super(DeQuantStub, self).__init__()
        self.qconfig = qconfig

    def forward(self, x):
        return x

模块属性

  • qconfig: 量化配置字典。

模块方法

  • forward(x): 该方法只是简单地返回输入 x,不做任何处理。

模块功能

DeQuantStub 模块主要用于将定点张量转换为浮点张量。在动态量化过程中,DeQuantStub 模块会被替换为 nnq.DeQuantize 模块,nnq.DeQuantize 模块会将定点张量转换为浮点张量,从而恢复模型的精度。

示例

以下是一个简单的示例,演示如何使用 DeQuantStub 模块将定点张量转换为浮点张量:

import torch
import torch.nn as nn
import torch.quantization

# 定义模型
model = nn.Sequential(
    QuantStub(qconfig={'dtype': torch.qint8}),
    nn.Linear(10, 20),
    QuantStub(qconfig={'dtype': torch.qint8}),
    nn.ReLU(),
    QuantStub(qconfig={'dtype': torch.qint8}),
    nn.Linear(20, 1),
    DeQuantStub(qconfig={'dtype': torch.qint8})
)

# 将模型转换为动态量化模型
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear: torch.quantization.QuantStub, nn.ReLU: torch.quantization.QuantStub},
    dtype=torch.qint8
)

# 测试模型
input = torch.randn(1, 10)
output = quantized_model(input)
print(output)

在这个示例中,我们在模型的最后添加了一个 DeQuantStub 模块。DeQuantStub 模块会将模型输出的定点张量转换为浮点张量,从而恢复模型的精度。

注意事项

在使用 torch.quantization.DeQuantStub 模块时,需要注意以下几点:

  • DeQuantStub 模块只用于将定点张量转换为浮点张量,不进行任何量化操作。
  • DeQuantStub 模块必须与 torch.quantization.QuantStub 模块搭配使用,才能完成动态量化。
  • 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。

更多资源

posted @ 2024-06-16 16:30  立体风  阅读(158)  评论(1编辑  收藏  举报