Max_Pool模块完善
Max_Pool模块完善
什么是最大池化层(Max Pooling Layer)?
- 最大池化层是一种常用的池化层(Pooling Layer),它的作用是对输入的特征图(Feature Map)进行降维压缩,以加快运算速度,减少参数数量,防止过拟合,提高模型的尺度不变性和旋转不变性 。
- 最大池化层的原理是:在前向传播过程中,对每个特征图的区域(通常是2x2或3x3的窗口),选择其中的最大值作为该区域池化后的值;在反向传播过程中,梯度只通过前向传播时的最大值反向传播,其他位置的梯度为0 。
- 最大池化层可以分为重叠池化(Overlapping Pooling)和非重叠池化(Non-overlapping Pooling),区别在于窗口移动的步长(Stride)是否等于窗口大小 。
- 重叠池化的步长小于窗口大小,例如AlexNet/GoogLeNet系列中采用的3x3窗口,步长为2的重叠池化。
- 非重叠池化的步长等于窗口大小,例如VGG系列中采用的2x2窗口,步长为2的非重叠池化。
- 最大池化层的优点在于它能学习到图像的边缘和纹理结构,同时保留了一定的空间信息 。
最大池化层有以下几个优点:
- 减少模型的参数数量和计算量,从而提高模型的效率和速度。
- 增强模型对输入图像中特征位置变化的鲁棒性,从而提高模型的泛化能力。
- 提取更高层次或更抽象的特征,从而增强模型的表达能力。
目的
- 为了实现YOLO tiny网络中的最大值池化层,需要对Max_Pool模块进行完善,增加对步长等于1的最大值池化的支持。
- 步长等于1的最大值池化是在网络结构中的第11层,步长为1的最大池化层是在每个2x2的窗口内选择最大值作为输出,同时在输入特征图的两条边上填充0,使得输出特征图的尺寸与输入相同。
stride=1最大值池化的工作原理
- stride等于1的最大池化是一种特殊的最大池化层,它的作用是对输入的特征图进行填充(Padding)和平滑(Smoothing),而不改变特征图的尺寸。
- stride等于1的最大池化的原理是:在前向传播过程中,在输入特征图的两条相邻边(通常是上边和左边)填充0,使得特征图尺寸增加一个单位(例如13x13变成14x14);然后对每个特征图区域(通常是2x2窗口),选择其中的最大值作为该区域池化后的值;在反向传播过程中,梯度只通过前向传播时的最大值反向传播,其他位置的梯度为0。
- stride等于1的最大池化可以分为四种情况,根据填充边和移动方向不同而有所区别:
- 上左填充,右下移动:这是Yolo网络中使用的情况,在layer 11处采用了stride等于1,窗口大小为2x2的最大池化。
- 上右填充,左下移动:这种情况与上一种情况类似,只是填充边和移动方向相反。
- 下左填充,右上移动:这种情况与上一种情况类似,只是填充边和移动方向相反。
- 下右填充,左上移动:这种情况与上一种情况类似,只是填充边和移动方向相反。
- stride等于1的最大池化的优点在于它能平滑特征图,减少噪声,增强特征的鲁棒性。
- 例如,如果输入图像为:
D0 | D1 | D2 |
---|---|---|
D3 | D4 | D5 |
D6 | D7 | D8 |
则填充后的图像为:
0 | 0 | 0 | 0 |
---|---|---|---|
0 | D0 | D1 | D2 |
0 | D3 | D4 | D5 |
0 | D6 | D7 | D8 |
则输出图像为:
max(0,D0) | max(D0,D1) | max(D1,D2) |
---|---|---|
max(0,D3) | max(D3,D4) | max(D4,D5) |
max(0,D6) | max(D6,D7) | max(D7,D8) |
实现
时序:
- 输入数据和输出数据的时序控制
- 输入数据由 data_in 和 data_in_valid 信号组成,data_in_valid 信号为高时表示 data_in 信号有效。
- 输出数据由 data_out 和 data_out_valid 信号组成,data_out_valid 信号为高时表示 data_out 信号有效。
- 输入数据和输出数据都是按行顺序传输,每行有 13 个数据。
- 输入数据和输出数据之间有一定延迟,因为需要进行比较和 FIFO 的读写操作。
- 数据打拍和比较逻辑
- 数据打拍是指将输入数据分成两个寄存器,一个寄存器存储当前数据,另一个寄存器存储上一个数据,以便进行相邻两个数据的比较。
- 比较逻辑是指将两个寄存器中的数据进行比较,取最大值作为输出数据。
- 数据打拍和比较逻辑需要根据 stride 的值进行不同的处理。
- 当 stride 等于二时,数据打拍和比较逻辑只在 data_in_valid 信号为高时进行,即每两个数据进行一次打拍和比较。
- 当 stride 等于一时,数据打拍和比较逻辑需要在每个数据到来时进行,即每个数据进行一次打拍和比较。此外,还需要在每行的第一个数据到来之前,将上一个寄存器的值赋为零,以实现填充数据的效果。
- FIFO 的读写控制
- FIFO 是指先进先出的存储器,用于存储上一行的最大值结果,以便与下一行的最大值结果进行比较。
- FIFO 的读写控制需要根据 stride 的值进行不同的处理。
- 当 stride 等于二时,FIFO 的写使能信号由 row_even_flag 和 col_even_flag 信号控制,即每四个数据写入一次 FIFO。FIFO 的读使能信号由 row_even_flag 信号控制,即从第二行开始每两行读取一次 FIFO。
- 当 stride 等于一时,FIFO 的写使能信号由 data_in_valid 信号控制,即每个数据写入一次 FIFO。FIFO 的读使能信号由 row_cnt 信号控制,即从第二行开始每行读取一次 FIFO。此外,还需要在每次池化开始之前,将 FIFO 的数据清空,以避免干扰。
- 行计数器和列计数器
- 行计数器和列计数器用于记录当前处理的是第几行和第几列的数据,以便进行 FIFO 的读写控制和输出数据的时序控制。
- 行计数器和列计数器需要根据 stride 的值进行不同的处理。
- 当 stride 等于二时,行计数器和列计数器都是在 data_in_valid 信号的下降沿进行加一操作,即每两个数据加一次。
- 当 stride 等于一时,行计数器和列计数器都是在 data_in_valid 信号的上升沿进行加一操作,即每个数据加一次。
- stride 的判断和传递
- stride 的判断是指根据配置表格中的 bit13 位来判断当前层的池化步长是多少,bit13 位为零表示步长为二,bit13 位为一表示步长为一。
- stride 的传递是指将配置表格中的 bit13 位作为一个端口输入到池化模块中,并在池化模块中根据该端口来进行不同的处理。
- 填充数据的生成
- 填充数据的生成是指在 stride 等于一的情况下,在输入图像的两条边上添加零像素,以保证输出图像的尺寸和输入图像相同。
- 填充数据的生成可以通过在每行的第一个数据到来之前,将上一个寄存器的值赋为零来实现。这样就相当于在输入图像的上边和左边各添加了一行或一列零像素。
Example:
- 假设有一个\(3*3\)大小的输入图像,其数值如下
1 | 2 | 3 |
---|---|---|
4 | 5 | 6 |
7 | 8 | 9 |
对该输入图像进行\(stride=1\)最大值池化,首先需要对其进行填充,在其上边和左边添加一行或一列0,得到一个\(4*4\)大小的填充后图像:
0 | 0 | 0 | 0 |
---|---|---|---|
0 | 1 | 2 | 3 |
0 | 4 | 5 | 6 |
0 | 7 | 8 | 9 |
然后使用一个\(2*2\)大小的滑动窗口,在每个子区域内选择最大值作为输出,滑动窗口从左上角开始,每次向右或向下移动一个像素,直到覆盖整个填充后图像。我们可以得到一个\(3*3\)大小的输出图像,其数值如下:
max(0,1) = 1 | max(1,2) = 2 | max(2,3) = 3 |
---|---|---|
max(0,4) = 4 | max(4,5) = 5 | max(5,6) = 6 |
max(0,7) = 7 | max(7,8) = 8 | max(8,9) = 9 |
可以看到,输出图像与输入图像尺寸相同,且每个位置的值都是对应子区域的最大值。
代码框架:
module max_pool_stride_1(
input clk, rst, // 时钟和复位信号
input data_in_valid, // 输入数据的有效标志
input [7:0] data_in, // 输入数据,8位宽,每次输入一个数据
output reg data_out_valid, // 输出数据的有效标志
output reg [7:0] data_out // 输出数据,8位宽,每次输出一个数据
);
// 定义一些参数和信号
parameter data_width = 8; // 输入数据的位宽
parameter channel_num = 8; // 输入数据的通道数
parameter row_num = 13; // 输入数据的行数
parameter col_num = 13; // 输入数据的列数
parameter window_size = 2; // 窗口的大小
parameter stride = 1; // 窗口的步长
parameter fifo_depth = 14; // FIFO 缓存的深度
reg [7:0] data_in_ie; // 数据打拍后的输出
reg [7:0] max_data; // 比较器的输出,即相邻两个数据的最大值
reg fifo_write_en; // FIFO 缓存的写使能信号
reg fifo_read_en; // FIFO 缓存的读使能信号
wire [7:0] fifo_read_data; // FIFO 缓存的读数据信号
reg [3:0] row_cnt; // 行计数器,用于记录当前输入数据的行数
reg [3:0] col_cnt; // 列计数器,用于记录当前输入数据的列数
// 实例化 FIFO 缓存模块,使用 Xilinx 提供的 IP 核
FIFO_SRL #(
.DATA_WIDTH(data_width), // 数据位宽
.FIFO_DEPTH(fifo_depth), // FIFO 深度
.ALMOST_EMPTY_OFFSET(1), // 几乎空偏移量,用于产生几乎空标志信号
.ALMOST_FULL_OFFSET(1) // 几乎满偏移量,用于产生几乎满标志信号
) fifo (
.clk(clk), // 时钟信号
.srst(rst), // 同步复位信号
.din(max_data), // 写数据信号,即比较器的输出
.wr_en(fifo_write_en), // 写使能信号
.rd_en(fifo_read_en), // 读使能信号
.dout(fifo_read_data), // 读数据信号,即 FIFO 缓存中存储的上一行相邻两个数据的最大值
.full(), // 满标志信号,本例中不使用
.empty(), // 空标志信号,本例中不使用
.prog_full(), // 几乎满标志信号,本例中不使用
.prog_empty() // 几乎空标志信号,本例中不使用
);
// 使用时序逻辑来控制数据打拍和比较器的工作状态
always @(posedge clk) begin
if (rst) begin
data_in_ie <= 0; // 复位时将数据打拍后的输出清零
max_data <= 0; // 复位时将比较器的输出清零
row_cnt <= 0; // 复位时将行计数器清零
col_cnt <= 0; // 复位时将列计数器清零
end else begin
// 当输入数据有效时,进行数据打拍和比较
if (data_in_valid) begin
// 数据打拍,将当前输入数据和上一个输入数据同时输出
data_in_ie <= data_in;
// 比较器,比较相邻两个数据的大小,并输出较大的数据
if (data_in > data_in_ie) begin
max_data <= data_in;
end else begin
max_data <= data_in_ie;
end
// 行计数器,根据输入数据的有效标志的下降沿来计数
if (data_in_valid == 0 && data_in_ie == 1) begin
row_cnt <= row_cnt + 1;
end
// 列计数器,根据输入数据的有效标志来计数
col_cnt <= col_cnt + 1;
end else begin
// 当输入数据无效时,将数据打拍后的输出清零,保证第一行和第一列的数据只与 0 比较
data_in_ie <= 0;
end
end
end
graph LR
subgraph 输入特征图
A((1)) -- 2x2 窗口 --> B((0))
B -- 2x2 窗口 --> C((3))
D((4)) -- 2x2 窗口 --> E((6))
E -- 2x2 窗口 --> F((8))
G((9)) -- 2x2 窗口 --> H((5))
H -- 2x2 窗口 --> I((7))
end
subgraph 输出特征图
J((4)) --> K((6)) --> L((8))
M((9)) --> N((9)) --> O((8))
P((9)) --> Q((7)) --> R((7))
end
A -.-> J
B -.-> K
C -.-> L
D -.-> M
E -.-> N
F -.-> O
G -.-> P
H -.-> Q
I -.-> R
style B fill:#f9f,stroke:#333,stroke-width:4px;
style C fill:#f9f,stroke:#333,stroke-width:4px;
style E fill:#f9f,stroke:#333,stroke-width:4px;
style F fill:#f9f,stroke:#333,stroke-width:4px;
style K fill:#ff6,stroke:#333,stroke-width:4