pytorch如何批量reshape,如何每batch_size进行reshape

假设我有一个tensor,它的batch_size是2:

tensor = torch.randn([2, 6])
print(tensor.shape)

输出是

torch.Size([2, 6])

其中tensor.shape[0]代表tensor的batch_size
如果我要把其中每个Batch的数据从6转换成[2,3],怎么写?循环遍历tensor然后循环内用reshape吗?不!
看下面的操作,很简单:

tensor = torch.randn([2, 6])
    print(tensor)
    tensor = tensor.reshape(tensor.shape[0], 2, 3)  # 将每个批次的数据转换成2,3的形状
    print(tensor)
    tensor = tensor.reshape(tensor.shape[0], 6)  # 恢复原来的形状
    print(tensor)

输出是:

tensor([[-0.7920, -0.7887, -0.7362,  0.2238,  0.3442,  1.5486],
        [ 1.7589, -0.3414,  0.4499, -0.0228,  0.4032,  0.3730]])
tensor([[[-0.7920, -0.7887, -0.7362],
         [ 0.2238,  0.3442,  1.5486]],

        [[ 1.7589, -0.3414,  0.4499],
         [-0.0228,  0.4032,  0.3730]]])
tensor([[-0.7920, -0.7887, -0.7362,  0.2238,  0.3442,  1.5486],
        [ 1.7589, -0.3414,  0.4499, -0.0228,  0.4032,  0.3730]])

Process finished with exit code 0

但是要注意!需要改变形状的tensor里面的东西要符合要求!数量不够会报错!

posted @ 2022-09-18 10:11  猪猪猪猪侠  阅读(368)  评论(0编辑  收藏  举报