YOLO Int8量化模块实现
什么是量化?
- 量化是一种将浮点数转换为整数的方法,可以减少计算量和存储空间,提高模型的运行效率和部署能力。
- 量化的过程可以表示为:
\[Q(x) = round(\frac{x}{s}) + z
\]
- 其中,\(x\)是浮点数,\(s\)是缩放因子(scale factor),\(z\)是零点(zero point),\(Q(x)\)是量化后的整数。
- 缩放因子\(s\)和零点\(z\)可以根据不同的量化方法和范围来确定,例如对称量化(symmetric quantization)或非对称量化(asymmetric quantization),无符号整数(unsigned integer)或有符号整数(signed integer)等。
为什么要对YOLO进行Int8量化?
- YOLO是一种流行的目标检测算法,它可以快速地在图像中定位和识别多个目标。
- YOLO的模型通常使用浮点数来表示参数和特征,这样可以保证模型的精度和表达能力,但也带来了较高的计算量和存储空间的需求。
- 为了在移动设备或嵌入式系统上部署YOLO模型,需要对模型进行压缩和优化,以适应有限的资源和性能要求。
- Int8量化是一种常用的压缩和优化方法,它可以将浮点数转换为8位整数,从而减少模型的大小和运行时间,同时尽量保持模型的精度和效果。
如何对YOLO进行Int8量化?
-
对YOLO进行Int8量化的主要步骤如下:
- 确定量化方法和范围,例如使用非对称量化和无符号整数。
- 计算每一层的缩放因子\(s\)和零点\(z\),根据输入数据和权重的分布和范围来确定。
- 对每一层的输入数据和权重进行量化,即将浮点数转换为整数,根据公式\(Q(x) = round(\frac{x}{s}) + z\)来计算。
- 对每一层的输出数据进行反量化,即将整数转换为浮点数,根据公式\(x = s(Q(x) - z)\)来计算。
- 对每一层的卷积操作进行优化,使用整数乘法和移位代替浮点乘法和除法,减少计算量和提高精度。
- 对每一层的激活函数进行调整,使用整数表示激活输出的范围和零点,避免数据溢出或损失。
-
下面以第0层卷积层为例,详细说明如何对YOLO进行Int8量化。
第0层卷积层的Int8量化
-
第0层卷积层的输入数据和权重的缩放因子\(s\)和零点\(z\)分别为:
- 输入数据:\(s_1 = 0.00784314, z_1 = 127\)
- 权重:\(s_2 = 0.00248109, z_2 = 0\)
-
第0层卷积层的输入数据和权重的量化公式分别为:
- 输入数据:\(Q_1(x) = round(\frac{x}{s_1}) + z_1\)
- 权重:\(Q_2(x) = round(\frac{x}{s_2}) + z_2\)
-
第0层卷积层的输出数据的缩放因子\(s\)和零点\(z\)分别为:
- 输出数据:\(s_3 = s_1 \times s_2, z_3 = 86\)
-
第0层卷积层的输出数据的反量化公式为:
- 输出数据:\(x = s_3(Q_3(x) - z_3)\)
-
第0层卷积层的卷积操作的优化方法为:
-
使用整数乘法和移位代替浮点乘法和除法,即:
\[Q_3(x) = round(\frac{Q_1(x) \times Q_2(x)}{s_3}) + z_3 \]等价于:
\[Q_3(x) = \frac{(Q_1(x) \times Q_2(x)) \times M}{2^{15}} >> S + z_3 \] -
其中,\(M\)是一个整数,用来调整缩放因子\(s_3\)的大小,使其接近\(2^{15}\),以提高精度和避免溢出,例如:
\[M = round(\frac{2^{15}}{s_3}) \] -
\(S\)是一个整数,用来表示移位的次数,根据缩放因子\(s_3\)的大小来确定,例如:
\[S = -log_2(s_3) \] -
第0层卷积层的\(M\)和\(S\)分别为:
- \(M = 19290\)
- \(S = -4\)
-
-
第0层卷积层的激活函数的调整方法为:
-
使用整数表示激活输出的范围和零点,例如使用ReLU6作为激活函数,那么激活输出的范围为\([0, 6]\),零点为\(0\),则可以用整数表示为:
\[Q_a(x) = min(max(Q_3(x), 0), \frac{6}{s_a}) + z_a \] -
其中,\(s_a\)是激活输出的缩放因子,可以根据激活输出的范围来确定,例如:
\[s_a = \frac{6}{255} \] -
\(z_a\)是激活输出的零点,可以根据激活函数的类型来确定,例如对于ReLU6来说,零点为\(0\)。
-
-
第0层卷积层的激活函数没有进行调整,因为它还没有用到。但是在后面的层中,它会用到激活输出的缩放因子和零点。
YOLO Int8量化模块的实现过程如下:
- 首先,定义一个
quant_int8
模块,用于将24位有符号整数数据输入转换为8位有符号整数数据输出。该模块的输入参数包括mult(乘法因子),shift(右移位数),zero_point(零点偏移量)。该模块的内部逻辑如下:- 使用
mult_gen_0
模块对data_in和mult进行乘法运算,得到39位有符号整数结果mult_rslt。 - 使用always语句在时钟上升沿或复位信号下降沿对mult_rslt进行右移运算,并加上zero_point,得到8位有符号整数结果data_out。
- 使用
- 然后,定义一个
quant_int8_8ch
模块,用于将8个通道的24位有符号整数数据输入转换为8个通道的8位有符号整数数据输出。该模块的输入参数与quant_int8
模块相同,但是对每个通道都使用了一个quant_int8
模块实例来进行转换。该模块还使用了一个shift_reg
模块来延迟data_in_vld信号,使其与data_out_vld信号对齐。
YOLO Int8量化模块的时序分析如下:
quant_int8
模块的时钟周期为1ns,因此其乘法运算需要1ns完成,右移运算需要1ns完成,加法运算需要1ns完成。因此,quant_int8
模块的总延迟为3ns。quant_int8_8ch
模块的时钟周期也为1ns,因此其对每个通道的转换需要3ns完成。由于各个通道是并行处理的,因此quant_int8_8ch
模块的总延迟也为3ns。- shift_reg模块的时钟周期为1ns,因此其对data_in_vld信号的延迟为5ns。
- 因此,YOLO Int8量化模块的总延迟为3ns + 5ns = 8ns。
YOLO Int8量化模块可能遇到的问题和解决方法如下:
- 问题一:量化过程可能导致精度损失和信息丢失,影响YOLO网络的性能和准确度。
- 解决方法一:选择合适的量化参数(mult,shift,zero_point),使得量化后的数据分布尽可能接近原始数据分布。
- 解决方法二:在训练YOLO网络时,使用量化感知训练(Quantization-aware training)方法,使得网络能够适应量化后的数据,并减少精度损失。
- 问题二:量化过程可能导致溢出或欠流现象,使得数据超出范围或变为零。
- 解决方法一:在量化前,对数据进行范围检测和裁剪,使得数据不超过8位有符号整数的最大值或最小值。
- 解决方法二:在量化后,对数据进行饱和运算(Saturation arithmetic),使得数据不超过8位有符号整数的最大值或最小值。
代码清单:
quant_int8
module quant_int8(
// system signals
input sclk ,
input s_rst_n ,
//
input signed [23:0] data_in ,
//
input [14:0] mult ,
input [ 7:0] shift ,
input [ 7:0] zero_point ,
//
output reg [ 7:0] data_out
);
//========================================================================\
// =========== Define Parameter and Internal signals ===========
//========================================================================/
wire signed [38:0] mult_rslt ;
reg [23:0] shift_rslt ;
//=============================================================================
//************** Main Code **************
//=============================================================================
mult_gen_0 mult_gen_0_inst (
.CLK (sclk ), // input wire CLK
.A (data_in ), // input wire [23 : 0] A
.B (mult ), // input wire [14 : 0] B
.P (mult_rslt ) // output wire [38 : 0] P
);
always @(posedge sclk or negedge s_rst_n) begin
if(s_rst_n == 1'b0) begin
shift_rslt <= 'd0;
data_out <= 'd0;
end
else begin
shift_rslt <= mult_rslt[38:15] >> shift;
data_out <= shift_rslt + zero_point;
end
end
endmodule
quant_int8_8ch
module quant_int8_8ch(
// system signals
input sclk ,
input s_rst_n ,
//
input signed [23:0] ch0_data_in ,
input signed [23:0] ch1_data_in ,
input signed [23:0] ch2_data_in ,
input signed [23:0] ch3_data_in ,
input signed [23:0] ch4_data_in ,
input signed [23:0] ch5_data_in ,
input signed [23:0] ch6_data_in ,
input signed [23:0] ch7_data_in ,
input data_in_vld ,
//
input [14:0] mult ,
input [ 7:0] shift ,
input [ 7:0] zero_point ,
//
output wire [ 7:0] ch0_data_out ,
output wire [ 7:0] ch1_data_out ,
output wire [ 7:0] ch2_data_out ,
output wire [ 7:0] ch3_data_out ,
output wire [ 7:0] ch4_data_out ,
output wire [ 7:0] ch5_data_out ,
output wire [ 7:0] ch6_data_out ,
output wire [ 7:0] ch7_data_out ,
output wire data_out_vld
);
//========================================================================\
// =========== Define Parameter and Internal signals ===========
//========================================================================/
//=============================================================================
//************** Main Code **************
//=============================================================================
shift_reg #(
.DLY_CNT (5 )
)shift_reg_inst(
.sclk (sclk ),
.data_in (data_in_vld ),
.data_out (data_out_vld )
);
quant_int8 ch0_quant_int8_inst(
// system signals
.sclk (sclk ),
.s_rst_n (s_rst_n ),
//
.data_in (ch0_data_in ),
//
.mult (mult ),
.shift (shift ),
.zero_point (zero_point ),
//
.data_out (ch0_data_out )
);
quant_int8 ch1_quant_int8_inst(
// system signals
.sclk (sclk ),
.s_rst_n (s_rst_n ),
//
.data_in (ch1_data_in ),
//
.mult (mult ),
.shift (shift ),
.zero_point (zero_point ),
//
.data_out (ch1_data_out )
);
quant_int8 ch2_quant_int8_inst(
// system signals
.sclk (sclk ),
.s_rst_n (s_rst_n ),
//
.data_in (ch2_data_in ),
//
.mult (mult ),
.shift (shift ),
.zero_point (zero_point ),
//
.data_out (ch2_data_out )
);
quant_int8 ch3_quant_int8_inst(
// system signals
.sclk (sclk ),
.s_rst_n (s_rst_n ),
//
.data_in (ch3_data_in ),
//
.mult (mult ),
.shift (shift ),
.zero_point (zero_point ),
//
.data_out (ch3_data_out )
);
quant_int8 ch4_quant_int8_inst(
// system signals
.sclk (sclk ),
.s_rst_n (s_rst_n ),
//
.data_in (ch4_data_in ),
//
.mult (mult ),
.shift (shift ),
.zero_point (zero_point ),
//
.data_out (ch4_data_out )
);
quant_int8 ch5_quant_int8_inst(
// system signals
.sclk (sclk ),
.s_rst_n (s_rst_n ),
//
.data_in (ch5_data_in ),
//
.mult (mult ),
.shift (shift ),
.zero_point (zero_point ),
//
.data_out (ch5_data_out )
);
quant_int8 ch6_quant_int8_inst(
// system signals
.sclk (sclk ),
.s_rst_n (s_rst_n ),
//
.data_in (ch6_data_in ),
//
.mult (mult ),
.shift (shift ),
.zero_point (zero_point ),
//
.data_out (ch6_data_out )
);
quant_int8 ch7_quant_int8_inst(
// system signals
.sclk (sclk ),
.s_rst_n (s_rst_n ),
//
.data_in (ch7_data_in ),
//
.mult (mult ),
.shift (shift ),
.zero_point (zero_point ),
//
.data_out (ch7_data_out )
);
endmodule