repeat得到的是[b0 b1 b0 b1]现在需要[b0 b0 b1 b1]

pytorch 一个tensor 比如是
[b0
b1
]
用tensor.repeat(2)函数可以得到
[b0
b1
b0
b1
]
我现在想得到
[b0
b0
b1
b1
]
如何优雅的得到?

import torch
c = torch.randint(0, 9, (2, 3))
d = c.repeat(3, 1)
print(f"c={c}\nd={d}")

d = c.unsqueeze(1)
e = d.repeat(1, 3, 1)
print(d.size())
print(e.size())
print(e)
f = e.view(-1, e.size(-1))
print(f)

gpt给出的正确答案

c=tensor([[6, 5, 7],
        [5, 0, 3]])
d=tensor([[6, 5, 7],
        [5, 0, 3],
        [6, 5, 7],
        [5, 0, 3],
        [6, 5, 7],
        [5, 0, 3]])
torch.Size([2, 1, 3])
torch.Size([2, 3, 3])
tensor([[[6, 5, 7],
         [6, 5, 7],
         [6, 5, 7]],

        [[5, 0, 3],
         [5, 0, 3],
         [5, 0, 3]]])
tensor([[6, 5, 7],
        [6, 5, 7],
        [6, 5, 7],
        [5, 0, 3],
        [5, 0, 3],
        [5, 0, 3]])
posted @ 2024-02-19 16:48  无左无右  阅读(4)  评论(0编辑  收藏  举报