pytorch如何批量reshape,如何每batch_size进行reshape
假设我有一个tensor,它的batch_size是2:
tensor = torch.randn([2, 6])
print(tensor.shape)
输出是
torch.Size([2, 6])
其中tensor.shape[0]
代表tensor的batch_size
如果我要把其中每个Batch的数据从6转换成[2,3],怎么写?循环遍历tensor然后循环内用reshape吗?不!
看下面的操作,很简单:
tensor = torch.randn([2, 6])
print(tensor)
tensor = tensor.reshape(tensor.shape[0], 2, 3) # 将每个批次的数据转换成2,3的形状
print(tensor)
tensor = tensor.reshape(tensor.shape[0], 6) # 恢复原来的形状
print(tensor)
输出是:
tensor([[-0.7920, -0.7887, -0.7362, 0.2238, 0.3442, 1.5486],
[ 1.7589, -0.3414, 0.4499, -0.0228, 0.4032, 0.3730]])
tensor([[[-0.7920, -0.7887, -0.7362],
[ 0.2238, 0.3442, 1.5486]],
[[ 1.7589, -0.3414, 0.4499],
[-0.0228, 0.4032, 0.3730]]])
tensor([[-0.7920, -0.7887, -0.7362, 0.2238, 0.3442, 1.5486],
[ 1.7589, -0.3414, 0.4499, -0.0228, 0.4032, 0.3730]])
Process finished with exit code 0
但是要注意!需要改变形状的tensor里面的东西要符合要求!数量不够会报错!