Pytorch的repeat函数

repeat可以完成指定维度上的复制

import torch
a = torch.randn(3, 2)
a
tensor([[ 1.4169,  0.2761],
        [ 1.2145, -2.0269],
        [ 1.1322, -0.7117]])

 

b = a.repeat(1,2)
b,b.size() 
(tensor([[ 1.4169,  0.2761,  1.4169,  0.2761],
         [ 1.2145, -2.0269,  1.2145, -2.0269],
         [ 1.1322, -0.7117,  1.1322, -0.7117]]),
 torch.Size([3, 4]))

 

posted @ 2022-04-04 18:53  vv_869  阅读(427)  评论(0编辑  收藏  举报