torch.repeat 和 torch.repeat_interleave
** 结论
torch.repeat: 输入张量的从后往前的后面维度对应按照repeat中大小进行repeat操作(所以 输入张量维度>= repeat维度)。 假设输入张量为(a,b,c),repeat(x,y),则为b维度repeat x倍,c维度repeat y倍;最终输出维度为(a, bx, cy)
torch.repeat_interleave: 输入张量按照指定维度进行扩展,假设输入张量为(2,2),torch.repeat_interleave(y, 3, dim=1), 原输入张量大小为(2,2),则在维度1扩展3倍,最终为(2,6)。如果没有指定dim,则会将输入拉张开为1维向量再进行扩展
1.torch.repeat
x = torch.tensor([1, 2, 3])
x.repeat(4, 2), x.repeat(4, 2).shape, x.repeat(4, 2, 1).shape, x.repeat(2)
输出
(tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]]),
torch.Size([4, 6]),
torch.Size([4, 2, 3]),
tensor([1, 2, 3, 1, 2, 3]))
2.torch.repeat_interleave
x = torch.tensor([1, 2, 3])
x.repeat_interleave(2)
tensor([1, 1, 2, 2, 3, 3])
y = torch.tensor([[1, 2], [3, 4]])
torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])
torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])
torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
[3, 4],
[3, 4]])