常用api

tensor.repeat和torch.repeat_interleave

  • tensor.repeat()
    a = torch.tensor([[1,2],[3,4]])
    a.repeat((2,1))
    表示在行的维度复制2遍,列维度不变,结果为tensor([[1, 2], [3, 4],[1, 2],[3, 4]])
    a.repeat((2,3))
    结果为tensor([[1, 2, 1, 2, 1, 2],
    [3, 4, 3, 4, 3, 4],
    [1, 2, 1, 2, 1, 2],
    [3, 4, 3, 4, 3, 4]])
  • tensor.repeat_interleave(input, num or tensor, dim)
    当num为整形,则将input中每个元素复制num次,输出为一个list
    torch.repeat_interleave(a, 2), 结果为 tensor([1, 1, 2, 2, 3, 3, 4, 4])
    当为tensor时,则将dim维的数据按照tensor数组复制num次
    torch.repeat_interleave(a,torch.tensor([1, 2]), dim=0),结果为tensor([[1, 2],
    [3, 4],
    [3, 4]])
    torch.repeat_interleave(a,torch.tensor([1, 2]), dim=1),结果为tensor([[1, 2, 2],
    [3, 4, 4]])
    • 综上,repeat适合用来将整个张量进行复制,可以按照张量中某一维; 而 repeat_interleave适合用来将按照某个维度的某行或者列进行复制
posted @ 2021-05-19 23:22  哈哈哈喽喽喽  阅读(47)  评论(0)    收藏  举报