pytorch动态量化函数
PyTorch 动态量化 API
PyTorch 提供了丰富的动态量化 API,可以帮助开发者轻松地将模型转换为动态量化模型。主要 API 包括:
torch.quantization.quantize_dynamic:将模型转换为动态量化模型。
torch.quantization.QuantStub:观察模型层的输入和输出分布。
torch.quantization.Observer:收集模型层的统计信息。
torch.quantization.DeQuantStub:将定点结果转换回浮点数。
PyTorch torch.quantization.quantize_dynamic
函数详解
torch.quantization.quantize_dynamic
函数是 PyTorch 提供的用于动态量化模型的主要 API。该函数可以将浮点模型转换为动态量化模型,从而显著降低模型大小和提高推理速度。
函数定义
torch.quantization.quantize_dynamic(
model: torch.nn.Module,
qconfig: Dict[Type[torch.nn.Module], Dict],
dtype: torch.qscheme = torch.qint8
) -> torch.nn.Module
参数说明
model
: 要转换的浮点模型。qconfig
: 指定要量化的模块类型和量化配置。dtype
: 指定量化的定点数据类型,可以是torch.qint8
或torch.float16
。
函数返回值
quantize_dynamic
函数返回一个新的动态量化模型,该模型与原始模型具有相同的架构和功能。
函数功能
quantize_dynamic
函数主要执行以下操作:
- 遍历模型中的每个模块。
- 对于每个模块,检查其类型是否在
qconfig
中定义。 - 如果模块类型在
qconfig
中定义,则根据qconfig
中的配置对该模块进行动态量化。 - 将量化的模块替换到新的模型中。
动态量化配置
qconfig
参数用于指定要量化的模块类型和量化配置。qconfig
是一个字典,其中键是模块类型,值是量化配置字典。量化配置字典可以包含以下键:
- ``activation`: 指定激活的量化配置。
- ``weight`: 指定权重的量化配置。
- ``qscheme
: 指定量化方案,可以是
torch.per_tensor或
torch.per_channel`。 - ``dynamic`: 指定是否动态量化。
示例
以下是一个简单的示例,演示如何使用 quantize_dynamic
函数将模型转换为动态量化模型:
import torch
import torch.nn as nn
import torch.quantization
# 定义模型
model = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 1)
)
# 定义量化配置
qconfig = {
nn.Linear: {
'activation': {'dtype': torch.qint8},
'weight': {'dtype': torch.qint8}
}
}
# 将模型转换为动态量化模型
quantized_model = torch.quantization.quantize_dynamic(
model,
qconfig,
dtype=torch.qint8
)
# 测试模型
input = torch.randn(1, 10)
output = quantized_model(input)
print(output)
在这个示例中,我们定义了一个简单的模型,并使用 qconfig
参数指定了量化配置。qconfig
参数指示 quantize_dynamic
函数对模型中的所有 nn.Linear
模块进行动态量化,并将激活和权重量化为 torch.qint8
格式。
注意事项
在使用 torch.quantization.quantize_dynamic
函数时,需要注意以下几点:
- 动态量化可能会导致模型精度下降,需要根据具体情况权衡性能和精度。
- 动态量化目前还不支持所有模型类型和操作。
- 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。
PyTorch torch.quantization.QuantStub
模块详解
torch.quantization.QuantStub
模块是 PyTorch 提供的用于动态量化模型的观察模块。该模块可以观察模型层的输入和输出分布,并收集统计信息,为动态量化提供必要的数据支持。
模块定义
class QuantStub(nn.Module):
r"""Quantize stub module, before calibration, this is same as an observer,.
It will be swapped as nnq.Quantize in convert .
Parameters:
qconfig(Dict): quantization configuration for the tensor, if qconfig is not
provided, we will use the global qconfig
"""
def __init__(self, qconfig=None):
super(QuantStub, self).__init__()
self.qconfig = qconfig
def forward(self, x):
return x
模块属性
qconfig
: 量化配置字典。
模块方法
forward(x)
: 该方法只是简单地返回输入x
,不做任何处理。
模块功能
QuantStub
模块主要用于观察模型层的输入和输出分布,并收集统计信息。在动态量化过程中,QuantStub
模块会被替换为 nnq.Quantize
模块,nnq.Quantize
模块会使用收集的统计信息对输入进行量化。
示例
以下是一个简单的示例,演示如何使用 QuantStub
模块观察模型层的输入和输出分布:
import torch
import torch.nn as nn
import torch.quantization
# 定义模型
model = nn.Sequential(
QuantStub(qconfig={'dtype': torch.qint8}),
nn.Linear(10, 20),
QuantStub(qconfig={'dtype': torch.qint8}),
nn.ReLU(),
QuantStub(qconfig={'dtype': torch.qint8}),
nn.Linear(20, 1)
)
# 测试模型
input = torch.randn(1, 10)
output = model(input)
print(output)
在这个示例中,我们为模型中的每个层都添加了 QuantStub
模块。QuantStub
模块会观察每个层的输入和输出分布,并收集统计信息。
注意事项
在使用 torch.quantization.QuantStub
模块时,需要注意以下几点:
QuantStub
模块只用于观察模型层的输入和输出分布,不进行任何量化操作。QuantStub
模块必须与torch.quantization.DeQuantStub
模块搭配使用,才能完成动态量化。- 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。
PyTorch torch.quantization.Observer
模块详解
torch.quantization.Observer
模块是 PyTorch 提供的用于动态量化模型的观察模块。该模块可以观察模型层的输入和输出分布,并收集统计信息,为动态量化提供必要的数据支持。
模块定义
class Observer(nn.Module):
r"""
Observer module, which observes tensor quantization ranges for dynamic quantization.
It attaches to the downstream module to observe the output of the module
and records the min/max values for quantization.
Parameters:
dtype(torch.qscheme): quantization dtype, e.g torch.qint8
quant_scheme(torch.qscheme): quantization scheme, e.g torch.per_tensor or
torch.per_channel
"""
def __init__(self, dtype=torch.qint8, quant_scheme=torch.per_tensor):
super(Observer, self).__init__()
assert dtype in [
torch.qint8, torch.quint8, torch.bfloat16
], 'Only support torch.qint8, torch.quint8, torch.bfloat16 for now'
self.dtype = dtype
self.quant_scheme = quant_scheme
self.qmin = None
self.qmax = None
self._called_once = False
def forward(self, x):
r"""Calculates the min/max values for quantization.
Args:
x(torch.Tensor): The input tensor to observe.
Returns:
torch.Tensor: The input tensor.
"""
if not self._called_once:
self._called_once = True
if self.quant_scheme == torch.per_tensor:
self.qmin = x.min()
self.qmax = x.max()
elif self.quant_scheme == torch.per_channel:
self.qmin = x.data.min(dim=1)[0]
self.qmax = x.data.max(dim=1)[0]
else:
raise NotImplementedError
return x
模块属性
dtype
: 量化数据类型,可以是torch.qint8
、torch.quint8
或torch.bfloat16
。quant_scheme
: 量化方案,可以是torch.per_tensor
或torch.per_channel
。qmin
: 最小值。qmax
: 最大值。
模块方法
forward(x)
: 该方法计算输入x
的最小值和最大值,并将其存储在qmin
和qmax
属性中。
模块功能
Observer
模块主要用于观察模型层的输入和输出分布,并收集统计信息。在动态量化过程中,Observer
模块收集的统计信息将被用于计算量化参数,例如量化尺度和零点。
示例
以下是一个简单的示例,演示如何使用 Observer
模块观察模型层的输入和输出分布:
import torch
import torch.nn as nn
import torch.quantization
# 定义模型
model = nn.Sequential(
Observer(dtype=torch.qint8, quant_scheme=torch.per_tensor),
nn.Linear(10, 20),
Observer(dtype=torch.qint8, quant_scheme=torch.per_tensor),
nn.ReLU(),
Observer(dtype=torch.qint8, quant_scheme=torch.per_tensor),
nn.Linear(20, 1)
)
# 测试模型
input = torch.randn(1, 10)
output = model(input)
print(output)
在这个示例中,我们为模型中的每个层都添加了 Observer
模块。Observer
模块会观察每个层的输入和输出分布,并收集统计信息。
注意事项
在使用 torch.quantization.Observer
模块时,需要注意以下几点:
Observer
模块只用于观察模型层的输入和输出分布,不进行任何量化操作。Observer
模块必须与torch.quantization.DeQuantStub
模块搭配使用,才能完成动态量化。- 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。
PyTorch torch.quantization.DeQuantStub
模块详解
torch.quantization.DeQuantStub
模块是 PyTorch 提供的用于动态量化模型的反量化模块。该模块可以将定点张量转换为浮点张量,从而恢复模型的精度。
模块定义
class DeQuantStub(nn.Module):
r"""Dequantize stub module, before calibration, this is same as identity,.
It will be swapped as nnq.DeQuantize in convert .
Parameters:
qconfig(Dict): quantization configuration for the tensor, if qconfig is not
provided, we will use the global qconfig
"""
def __init__(self, qconfig=None):
super(DeQuantStub, self).__init__()
self.qconfig = qconfig
def forward(self, x):
return x
模块属性
qconfig
: 量化配置字典。
模块方法
forward(x)
: 该方法只是简单地返回输入x
,不做任何处理。
模块功能
DeQuantStub
模块主要用于将定点张量转换为浮点张量。在动态量化过程中,DeQuantStub
模块会被替换为 nnq.DeQuantize
模块,nnq.DeQuantize
模块会将定点张量转换为浮点张量,从而恢复模型的精度。
示例
以下是一个简单的示例,演示如何使用 DeQuantStub
模块将定点张量转换为浮点张量:
import torch
import torch.nn as nn
import torch.quantization
# 定义模型
model = nn.Sequential(
QuantStub(qconfig={'dtype': torch.qint8}),
nn.Linear(10, 20),
QuantStub(qconfig={'dtype': torch.qint8}),
nn.ReLU(),
QuantStub(qconfig={'dtype': torch.qint8}),
nn.Linear(20, 1),
DeQuantStub(qconfig={'dtype': torch.qint8})
)
# 将模型转换为动态量化模型
quantized_model = torch.quantization.quantize_dynamic(
model,
{nn.Linear: torch.quantization.QuantStub, nn.ReLU: torch.quantization.QuantStub},
dtype=torch.qint8
)
# 测试模型
input = torch.randn(1, 10)
output = quantized_model(input)
print(output)
在这个示例中,我们在模型的最后添加了一个 DeQuantStub
模块。DeQuantStub
模块会将模型输出的定点张量转换为浮点张量,从而恢复模型的精度。
注意事项
在使用 torch.quantization.DeQuantStub
模块时,需要注意以下几点:
DeQuantStub
模块只用于将定点张量转换为浮点张量,不进行任何量化操作。DeQuantStub
模块必须与torch.quantization.QuantStub
模块搭配使用,才能完成动态量化。- 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。
更多资源
- PyTorch 动态量化文档:https://pytorch.org/
- 动态量化教程:https://blog.csdn.net/lk142500/article/details/138860037
- PyTorch 量化感知训练示例:https://github.com/leimao/PyTorch-Quantization-Aware-Training