矩阵max_pooling 二维矩阵滑动窗口

题目链接

题目大意
给定M×N矩阵,求经过给定size为A×B的最大池化处理后结果
M, N <= 2e3, A <= M, B <= N
直接上二维线段树超时了。
这里用滑动窗口,先对每行使用滑动窗口,再对得到的数组的每列使用滑动窗口即可。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <deque>
using namespace std;
const int N = 2e3 + 5;
int t[N][N];
int ans[N][N];
deque<int> win(N);

// 标准的滑动窗口
void solve1(int n, int len, int *arr, int *ans) {
	win.clear();
    for(int i = 0; i < n; i ++) {
        while(!win.empty() && arr[win.back()] < arr[i])
			win.pop_back();
		while(!win.empty() && win.front() < i - len + 1)
			win.pop_front();
		win.push_back(i);
		if(i + 1 >= len)
			ans[i - len + 1] = arr[win.front()];
    }
}


int main() {
    ios::sync_with_stdio(false);
	int m, n, a, b;
    cin >> m >> n >> a >> b;
    for(int i = 0; i < m; i ++)
        for(int j = 0; j < n; j ++)
            cin >> t[i][j];
	for(int i = 0; i < m; i ++)
		solve1(n, b, t[i], ans[i]);

    // 遍历每列
    // 第二次用作滑动窗口的数组是经过第一次处理后的
	int len = a;
	for(int i = 0; i < n; i ++) {
		win.clear();
		for(int j = 0; j < m; j ++) {
			while(!win.empty() && ans[win.back()][i] < ans[j][i])
				win.pop_back();
			while(!win.empty() && win.front() < j - len + 1)
				win.pop_front();
			win.push_back(j);
			if(j + 1 >= len)
				ans[j - len + 1][i] = ans[win.front()][i];
		}
	}

	for(int i = 0; i < m - a + 1; i ++)
		for(int j = 0; j < n - b + 1; j ++)
			cout << ans[i][j] << (j == n - b ? '\n' : ' ');
	return 0;
}
```
`timescale 1ns / 1ps module tb_pool_module; // 输入信号 reg clk; reg rst; reg en; reg pool_mode; reg stride; reg [7:0] data; reg data_vld; // 输出信号 wire [7:0] O_data; wire O_data_vld; // 时钟参数 parameter CLK_PERIOD = 10; // 100MHz // 实例化被测模块 pool_module dut ( .clk(clk), .rst(rst), .en(en), .pool_mode(pool_mode), .stride(stride), .data(data), .data_vld(data_vld), .O_data(O_data), .O_data_vld(O_data_vld) ); // 时钟生成 always #(CLK_PERIOD/2) clk = ~clk; // 测试序列 initial begin // 初始化信号 initialize(); // 测试1: 步长1的最大池化 $display(&quot;=== 测试1: 步长1的最大池化 ===&quot;); test_pooling(1'b0, 1'b0); // 测试2: 步长1的平均池化 $display(&quot;=== 测试2: 步长1的平均池化 ===&quot;); test_pooling(1'b0, 1'b1); // 测试3: 步长2的最大池化 $display(&quot;=== 测试3: 步长2的最大池化 ===&quot;); test_pooling(1'b1, 1'b0); // 测试4: 步长2的平均池化 $display(&quot;=== 测试4: 步长2的平均池化 ===&quot;); test_pooling(1'b1, 1'b1); // 完成测试 #100; $display(&quot;=== 所有测试完成 ===&quot;); $finish; end // 初始化任务 task initialize; begin clk = 0; rst = 1; en = 0; pool_mode = 0; stride = 0; data = 0; data_vld = 0; // 保持复位一段时间 #(CLK_PERIOD * 2); rst = 0; #(CLK_PERIOD * 2); end endtask // 池化测试任务 task test_pooling; input test_stride; input test_pool_mode; begin integer i, j; reg [7:0] test_matrix [0:3][0:3]; reg [7:0] expected_output [0:2][0:2]; integer output_count; // 设置测试模式 stride = test_stride; pool_mode = test_pool_mode; en = 1; // 生成测试数据矩阵 (4x4) $display(&quot;输入矩阵:&quot;); for (i = 0; i < 4; i = i + 1) begin for (j = 0; j < 4; j = j + 1) begin test_matrix[i][j] = i * 4 + j + 1; $write(&quot;%2d &quot;, test_matrix[i][j]); end $display(&quot;&quot;); end // 发送数据到池化模块 $display(&quot;发送数据...&quot;); for (i = 0; i < 4; i = i + 1) begin for (j = 0; j < 4; j = j + 1) begin @(posedge clk); data = test_matrix[i][j]; data_vld = 1; end end @(posedge clk); data_vld = 0; // 计算期望输出 if (test_stride == 0) begin // 步长1,输出3x3 $display(&quot;期望输出 (3x3):&quot;); for (i = 0; i < 3; i = i + 1) begin for (j = 0; j < 3; j = j + 1) begin if (test_pool_mode == 0) begin // 最大池化 expected_output[i][j] = max4( test_matrix[i][j], test_matrix[i][j+1], test_matrix[i+1][j], test_matrix[i+1][j+1] ); end else begin // 平均池化 expected_output[i][j] = avg4( test_matrix[i][j], test_matrix[i][j+1], test_matrix[i+1][j], test_matrix[i+1][j+1] ); end $write(&quot;%2d &quot;, expected_output[i][j]); end $display(&quot;&quot;); end end else begin // 步长2,输出2x2 $display(&quot;期望输出 (2x2):&quot;); for (i = 0; i < 2; i = i + 1) begin for (j = 0; j < 2; j = j + 1) begin if (test_pool_mode == 0) begin // 最大池化 expected_output[i][j] = max4( test_matrix[i*2][j*2], test_matrix[i*2][j*2+1], test_matrix[i*2+1][j*2], test_matrix[i*2+1][j*2+1] ); end else begin // 平均池化 expected_output[i][j] = avg4( test_matrix[i*2][j*2], test_matrix[i*2][j*2+1], test_matrix[i*2+1][j*2], test_matrix[i*2+1][j*2+1] ); end $write(&quot;%2d &quot;, expected_output[i][j]); end $display(&quot;&quot;); end end // 等待并检查输出 $display(&quot;实际输出:&quot;); output_count = 0; if (test_stride == 0) begin // 等待9个输出 (3x3) while (output_count < 9) begin @(posedge clk); if (O_data_vld) begin $write(&quot;%2d &quot;, O_data); output_count = output_count + 1; if (output_count % 3 == 0) $display(&quot;&quot;); end end end else begin // 等待4个输出 (2x2) while (output_count < 4) begin @(posedge clk); if (O_data_vld) begin $write(&quot;%2d &quot;, O_data); output_count = output_count + 1; if (output_count % 2 == 0) $display(&quot;&quot;); end end end // 测试间隔 #(CLK_PERIOD * 10); end endtask // 辅助函数:计算4个数的最大值 function [7:0] max4; input [7:0] a, b, c, d; reg [7:0] max_val; begin max_val = a; if (b > max_val) max_val = b; if (c > max_val) max_val = c; if (d > max_val) max_val = d; max4 = max_val; end endfunction // 辅助函数:计算4个数的平均值 function [7:0] avg4; input [7:0] a, b, c, d; reg [9:0] sum; begin sum = a + b + c + d; avg4 = sum >> 2; end endfunction // 波形保存 initial begin $dumpfile(&quot;tb_pool_module.vcd&quot;); $dumpvars(0, tb_pool_module); end // 超时保护 initial begin #1000000; $display(&quot;错误: 仿真超时!&quot;); $finish; end endmodule池化层测试代码纠错
最新发布
11-14
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值