# Code:
# -*- coding: utf-8 -*-
"""Convert 4-D NCHW arrays to and from the blocked 5-D NC1HWC0 layout.

The channel axis C is split into C1 = ceil(C / C0) blocks of C0 channels
each, and the C0 axis becomes the innermost axis. Channel ``i`` is stored
in block ``i // C0`` at lane ``i % C0``; lanes past the last real channel
are zero-filled.

Two implementations are provided so they can cross-check each other:
an element-by-element reference (``nchw_to_nc1hwc0`` / ``nc1hwc0_to_nchw``)
and a vectorized one (``nchw_to_nc1hwc0_1`` / ``nc1hwc0_to_nchw_1``).
"""
import numpy as np


def _num_c1(num_channels, block_size):
    """Return C1, the number of C0-wide channel blocks: ceil(num_channels / block_size)."""
    return (num_channels + block_size - 1) // block_size


def nchw_to_nc1hwc0(data, batch_size, num_channels, height, width, block_size):
    """Convert NCHW -> NC1HWC0 with explicit per-element loops (reference version).

    Parameters
    ----------
    data : numpy.ndarray
        Array of shape ``(batch_size, num_channels, height, width)``.
    batch_size, num_channels, height, width : int
        Expected dimensions of ``data``.
    block_size : int
        C0, the number of channels packed into each block.

    Returns
    -------
    numpy.ndarray
        New array of shape ``(batch_size, C1, height, width, block_size)`` and
        dtype ``data.dtype``, where ``C1 = ceil(num_channels / block_size)``.
        Unused lanes of the last block are zero.

    Raises
    ------
    AssertionError
        If ``data.shape`` does not match the given dimensions.
    """
    # Shape checks use ``assert`` to keep the existing AssertionError contract;
    # note they are skipped when Python runs with -O.
    expected = (batch_size, num_channels, height, width)
    assert data.shape == expected, f"expected NCHW shape {expected}, got {data.shape}"
    c0 = block_size
    c1 = _num_c1(num_channels, c0)
    out = np.zeros((batch_size, c1, height, width, c0), dtype=data.dtype)
    for b in range(batch_size):
        for i in range(num_channels):
            # Block/lane of channel i do not depend on the spatial position.
            c1_idx, c0_idx = divmod(i, c0)
            for j in range(height):
                for k in range(width):
                    out[b, c1_idx, j, k, c0_idx] = data[b, i, j, k]
    return out


def nc1hwc0_to_nchw(data, batch_size, num_channels, height, width, block_size):
    """Convert NC1HWC0 -> NCHW with explicit per-element loops (reference version).

    Parameters
    ----------
    data : numpy.ndarray
        Array of shape ``(batch_size, C1, height, width, block_size)`` with
        ``C1 = ceil(num_channels / block_size)``.
    batch_size, num_channels, height, width : int
        Dimensions of the NCHW result.
    block_size : int
        C0, the number of channels packed into each block.

    Returns
    -------
    numpy.ndarray
        New array of shape ``(batch_size, num_channels, height, width)`` and
        dtype ``data.dtype``; padding lanes of the input are dropped.

    Raises
    ------
    AssertionError
        If ``data.shape`` does not match the given dimensions.
    """
    c0 = block_size
    expected = (batch_size, _num_c1(num_channels, c0), height, width, c0)
    assert data.shape == expected, f"expected NC1HWC0 shape {expected}, got {data.shape}"
    out = np.zeros((batch_size, num_channels, height, width), dtype=data.dtype)
    for b in range(batch_size):
        for i in range(num_channels):
            c1_idx, c0_idx = divmod(i, c0)
            for j in range(height):
                for k in range(width):
                    out[b, i, j, k] = data[b, c1_idx, j, k, c0_idx]
    return out


def nchw_to_nc1hwc0_1(data, batch_size, num_channels, height, width, block_size):
    """Convert NCHW -> NC1HWC0 using vectorized NumPy reshapes.

    Same contract as :func:`nchw_to_nc1hwc0`: returns a new C-contiguous array
    of shape ``(batch_size, C1, height, width, block_size)`` and dtype
    ``data.dtype``, zero-padded in the unused lanes. Raises AssertionError on
    a shape mismatch.
    """
    expected = (batch_size, num_channels, height, width)
    assert data.shape == expected, f"expected NCHW shape {expected}, got {data.shape}"
    c0 = block_size
    c1 = _num_c1(num_channels, c0)
    # Zero-pad the channel axis up to a whole number of C0-wide blocks.
    padded = np.zeros((batch_size, c1 * c0, height, width), dtype=data.dtype)
    padded[:, :num_channels] = data
    # Channel index i == c1_idx * c0 + c0_idx, so a C-order reshape splits it
    # into (C1, C0); then move C0 to the innermost axis. ``.copy()`` returns a
    # fresh C-contiguous array rather than a strided view.
    blocked = padded.reshape(batch_size, c1, c0, height, width)
    return blocked.transpose(0, 1, 3, 4, 2).copy()


def nc1hwc0_to_nchw_1(data, batch_size, num_channels, height, width, block_size):
    """Convert NC1HWC0 -> NCHW using vectorized NumPy reshapes.

    Same contract as :func:`nc1hwc0_to_nchw`: returns a new C-contiguous array
    of shape ``(batch_size, num_channels, height, width)`` and dtype
    ``data.dtype``, dropping padding lanes. Raises AssertionError on a shape
    mismatch.
    """
    c0 = block_size
    c1 = _num_c1(num_channels, c0)
    expected = (batch_size, c1, height, width, c0)
    assert data.shape == expected, f"expected NC1HWC0 shape {expected}, got {data.shape}"
    # Move C0 next to C1, merge them back into one channel axis, drop padding.
    # ``.copy()`` guarantees the result never aliases ``data`` (a reshape can
    # be a view, e.g. when height == width == 1).
    merged = data.transpose(0, 1, 4, 2, 3).reshape(batch_size, c1 * c0, height, width)
    return merged[:, :num_channels].copy()


def main():
    """Round-trip random data through both implementations and cross-check them."""
    batch_size = 6
    num_channels = 11  # any positive integer works
    height = 7
    width = 11
    block_size = 16
    c1 = _num_c1(num_channels, block_size)

    data = np.random.rand(batch_size, num_channels, height, width)

    nc1hwc0_data = nchw_to_nc1hwc0(data, batch_size, num_channels, height, width, block_size)
    nc1hwc0_data_1 = nchw_to_nc1hwc0_1(data, batch_size, num_channels, height, width, block_size)
    print(nc1hwc0_data.shape)  # (6, 1, 7, 11, 16) for the sizes above
    print(nc1hwc0_data_1.shape)
    assert nc1hwc0_data.shape == (batch_size, c1, height, width, block_size)
    assert np.array_equal(nc1hwc0_data, nc1hwc0_data_1)

    nchw_data = nc1hwc0_to_nchw(nc1hwc0_data, batch_size, num_channels, height, width, block_size)
    # Use the vectorized inverse on the vectorized forward result so the
    # comparisons below actually cross-check the two implementations.
    nchw_data_1 = nc1hwc0_to_nchw_1(
        nc1hwc0_data_1, batch_size, num_channels, height, width, block_size
    )
    print(nchw_data.shape)  # (6, 11, 7, 11) for the sizes above
    print(nchw_data_1.shape)

    assert np.allclose(data, nchw_data)  # round trip must reproduce the input
    assert np.allclose(data, nchw_data_1)
    assert np.allclose(nchw_data, nchw_data_1)


if __name__ == "__main__":
    main()