Max_Pool模块实现
Max_Pool模块实现
-
- Max_Pool模块是一种**池化(pooling)**操作,用于对输入的特征图(feature map)进行降采样(downsampling),从而减少参数数量,提高计算效率,防止过拟合。 - Max_Pool模块的原理是在输入的特征图上滑动一个固定大小的窗口(kernel),每次取窗口内的最大值作为输出的一个元素。窗口每次移动的距离称为步长(stride)。 - Max_Pool模块可以有不同的参数设置,如窗口大小、步长、填充(padding)等,影响输出的特征图的形状。 - Max_Pool模块通常接在卷积层(convolutional layer)后面,用于提取输入特征图中的最显著的特征,忽略一些细节信息。
-
最大值池化:一种降低图像数据维度的方法,可以提高计算效率和抽象能力
- 原理:在一个固定大小的窗口内,选取窗口内的最大值作为输出
- 参数:窗口大小为(2x2)、步长(2)、无填充
- 示例:
- 输入特征图:$$\begin{bmatrix}1 & 2 & 3 & 4 \ 5 & 6 & 7 & 8 \ 9 & 10 & 11 & 12 \ 13 & 14 & 15 & 16 \end{bmatrix}$$
- 输出特征图:$$\begin{bmatrix}6 & 8 \ 14 & 16 \end{bmatrix}$$
-
常规方案:先构造一个2*2的矩阵窗口,再分别比较窗口内的四个值,最后得到最大值
-
- 需要调用FIFO,存储图像数据,同步输出同列数据
- 需要定义两个寄存器,构造2*2的矩阵
- FIFO的深度至少是图像一行的数据量,例如416
-
优点:逻辑简单,易于理解
-
缺点:需要调用FIFO存储图像数据,FIFO深度至少为一行图像数据量,占用较多存储资源!
-
-
优化方案:先比较相邻两个数据的最大值,再将结果写入FIFO,然后再比较FIFO读出的数据和当前数据的最大值,得到最终结果
先比较相邻两个数据的最大值,再比较两个最大值的最大值,不需要构造矩阵窗口
- 优点:存储资源占用较少,FIFO深度为图像一行数据量的一半,不需要额外的寄存器
- 缺点:时序要求更高,需要在每个数据到来时进行比较和写入操作
- 在图像数据过来的时候,先在内部定义一个计算器,对数据做一个打拍
- 先比较一行里面相邻的两个数据的最大值,然后把结果写到FIFO里面
- 当第二行数据过来的时候,再比较一行里面相邻的两个数据的最大值,然后从FIFO里面读出第一行的最大值,和当前的最大值再做一次比较,得到最终的池化结果
- 这样可以节省FIFO的深度一半,只需要图像一行数据量的一半,例如第一层是208
- 时序上面要求稍微高一点,需要在每个数据到来之后做比较和写入或读出操作
- 示例
- 假设输入图像数据如下:
00 01 02 03 10 11 12 13 20 21 22 23 30 31 32 33 - 那么输出图像数据如下:
11 13 31 33 - 具体过程如下:
- 当第一行数据00,01,02,03进入模块时,先用一个寄存器R1对数据进行缓存和打拍(Shift Register)(即延迟一个时钟周期)为00,01,02,03
- 比较器(Comparator)比较相邻两个数据00和01,得到最大值01,并通过
fifo_wr_data
写入FIFO - 比较器比较相邻两个数据02和03,得到最大值03,并通过
fifo_wr_data
写入FIFO - 此时FIFO中有两个数据01和03,深度为一行的一半(208)
- 当第二行数据10,11,12,13进入模块时,R1同样进行缓存、打拍和比较为10,11,12,13
- 比较器比较相邻两个数据10和11,得到最大值11,并暂存为
Max_Data
- FIFO读出第一行的第一个最大值01,并和
Max_Data
比较,得到更大的值\(11\),并输出为Out_Data
(pool_data
) - 比较器比较相邻两个数据12和13,得到最大值13,并暂存为
Max_Data
- FIFO读出第一行的第二个最大值03,并和
Max_Data
比较,得到更大的值\(13\),并输出为Out_Data
(pool_data
) - 此时输出的第一行数据为\(11\)和\(13\)
- 以此类推,可以得到输出的第二行数据为\(31\)和\(33\)
- 时序图:
-
最大值池化的代码实现
- 定义输入输出端口和内部信号
module max_pool( input clk, input rst_n, input data_in_valid, input [7:0] data_in, output reg data_out_valid, output reg [7:0] data_out ); reg [7:0] data_in_r1; reg [7:0] fifo_write_data; reg fifo_write_en; reg [7:0] max_data; reg [7:0] fifo_read_data; reg fifo_read_en; reg row_even_flag; reg data_in_value_r1;
- 对输入数据做打拍和比较,得到fifo_write_data和fifo_write_en
always @(posedge clk or negedge rst_n) begin if (~rst_n) begin data_in_r1 <= 8'h00; end else begin if (data_in_valid) begin data_in_r1 <= data_in; end else begin data_in_r1 <= 8'h00; end end end always @(*) begin fifo_write_data = 8'h00; if (data_in_r1 >= data_in) begin fifo_write_data = data_in_r1; end else begin fifo_write_data = data_in; end end assign fifo_write_en = ~row_even_flag & data_in_value_r1 & data_in_valid;
- 调用FIFO模块,存储每行相邻两个数据的最大值,并读出第一行的最大值
fifo_256x8 fifo_inst( .clk(clk), .rst_n(rst_n), .din(fifo_write_data), .wr_en(fifo_write_en), .rd_en(fifo_read_en), .dout(fifo_read_data) );
- 比较第一行和第二行的最大值,得到data_out和data_out_valid
always @(*) begin data_out = 8'h00; if (max_data >= fifo_read_data) begin data_out = max_data; end else begin data_out = fifo_read_data; end end assign data_out_valid = fifo_read_en;
- 定义奇偶数行的标志信号和上升沿和下降沿的标志信号,控制时序逻辑
always @(posedge clk or negedge rst_n) begin if (~rst_n) begin row_even_flag <= 1'b0; end else begin if (data_in_value_r1 & ~data_in_valid) begin row_even_flag <= ~row_even_flag; end end end always @(posedge clk or negedge rst_n) begin if (~rst_n) begin data_in_value_r1 <= 1'b0; end else begin if (data_in_valid) begin data_in_value_r1 <= data_in_valid; end else begin data_in_value_r1 <= 1'b0; end end end
max_pool_ch0
单通道最大值池化模块主要代码:
assign max_data = fifo_wr_data; always @(posedge sclk) begin data_in_r1 <= data_in; end always @(posedge sclk or negedge s_rst_n) begin if(s_rst_n == 1'b0) fifo_wr_data <= 'd0; else if(data_in_r1 >= data_in) fifo_wr_data <= data_in_r1; else fifo_wr_data <= data_in; end always @(posedge sclk or negedge s_rst_n) begin if(s_rst_n == 1'b0) data_out <= 'd0; else if(max_data >= fifo_rd_data) data_out <= max_data; else data_out <= fifo_rd_data; end pool_fifo_ip pool_fifo_ip ( .clk (sclk ), // input wire clk .srst (~s_rst_n ), // input wire srst .din (fifo_wr_data ), // input wire [7 : 0] din .wr_en (fifo_wr_en ), // input wire wr_en .rd_en (fifo_rd_en ), // input wire rd_en .dout (fifo_rd_data ), // output wire [7 : 0] dout .full ( ), // output wire full .empty ( ), // output wire empty .data_count ( ) // output wire [8 : 0] data_count );
max_pool_8ch
8通道最大值池化模块主要代码:
- 将单通道模块复制8份(max_pool_ch0 ~ max_pool_ch7)
- 将输入输出数据合并为8位宽(data_in, data_out)
- 将FIFO读写使能信号统一在外部定义(fifo_wr_en, fifo_rd_en)
always @(posedge sclk) begin data_in_vld_r1 <= data_in_vld; data_out_vld <= fifo_rd_en; end always @(posedge sclk or negedge s_rst_n) begin if(s_rst_n == 1'b0) col_even_flag <= 1'b0; else if(data_in_vld == 1'b1) col_even_flag <= ~col_even_flag; else col_even_flag <= 1'b0; end always @(posedge sclk or negedge s_rst_n) begin if(s_rst_n == 1'b0) fifo_wr_en <= 1'b0; else if(row_even_flag == 1'b1 && col_even_flag == 1'b1) fifo_wr_en <= 1'b1; else fifo_wr_en <= 1'b0; end always @(posedge sclk or negedge s_rst_n) begin if(s_rst_n == 1'b0) fifo_rd_en <= 1'b0; else if(row_even_flag == 1'b0 && col_even_flag == 1'b1) fifo_rd_en <= 1'b1; else fifo_rd_en <= 1'b0; end always @(posedge sclk or negedge s_rst_n) begin if(s_rst_n == 1'b0) row_even_flag <= 1'b0; else if(data_in_vld == 1'b1 && data_in_vld_r1 == 1'b0) row_even_flag <= ~row_even_flag; end max_pool ch0_max_pool_inst( // system signals .sclk (sclk ), .s_rst_n (s_rst_n ), // .data_in (ch0_data_in ), .data_out (ch0_data_out ), .fifo_wr_en (fifo_wr_en ), .fifo_rd_en (fifo_rd_en ) ); max_pool ch1_max_pool_inst( // system signals .sclk (sclk ), .s_rst_n (s_rst_n ), // .data_in (ch1_data_in ), .data_out (ch1_data_out ), .fifo_wr_en (fifo_wr_en ), .fifo_rd_en (fifo_rd_en ) ); max_pool ch2_max_pool_inst( // system signals .sclk (sclk ), .s_rst_n (s_rst_n ), // .data_in (ch2_data_in ), .data_out (ch2_data_out ), .fifo_wr_en (fifo_wr_en ), .fifo_rd_en (fifo_rd_en ) ); max_pool ch3_max_pool_inst( // system signals .sclk (sclk ), .s_rst_n (s_rst_n ), // .data_in (ch3_data_in ), .data_out (ch3_data_out ), .fifo_wr_en (fifo_wr_en ), .fifo_rd_en (fifo_rd_en ) ); max_pool ch4_max_pool_inst( // system signals .sclk (sclk ), .s_rst_n (s_rst_n ), // .data_in (ch4_data_in ), .data_out (ch4_data_out ), .fifo_wr_en (fifo_wr_en ), .fifo_rd_en (fifo_rd_en ) ); max_pool ch5_max_pool_inst( // system signals .sclk (sclk ), .s_rst_n (s_rst_n ), // .data_in (ch5_data_in ), .data_out (ch5_data_out ), .fifo_wr_en (fifo_wr_en ), .fifo_rd_en (fifo_rd_en ) ); max_pool ch6_max_pool_inst( // system signals .sclk (sclk ), .s_rst_n (s_rst_n ), // .data_in (ch6_data_in ), .data_out (ch6_data_out ), .fifo_wr_en (fifo_wr_en ), .fifo_rd_en (fifo_rd_en ) ); max_pool ch7_max_pool_inst( // system signals .sclk (sclk ), .s_rst_n (s_rst_n ), // .data_in (ch7_data_in ), .data_out (ch7_data_out ), .fifo_wr_en (fifo_wr_en ), .fifo_rd_en (fifo_rd_en ) );