torch.Tensor.repeat和numpy.repeat区别
函数声明
numpy.repeat(a, repeats, axis=None)
numpy.repeat需要指定维度,repeats为一个数,数组在axis维度上重复repeats倍。
torch.Tensor.repeat(*sizes)
size是一个元组,元素个数同tensor.shape元素个数,size中元素为tensor相应维度重复倍数。
代码实战
import torch
import numpy as np
a = torch.rand(2, 3, 1)
b = a.repeat((4, 1, 1))
# a.shape (2, 3, 1) b.shape (8, 3, 1)
a = torch.rand(2, 3, 1)
b = a.repeat((1, 2, 1))
# a.shape (2, 3, 1) b.shape (2, 6, 1)
a = np.zeros((2, 3, 1))
b = a.repeat(4, axis=0)
# a.shape (2, 3, 1) b.shape (8, 3, 1)
a = np.zeros((2, 3, 1))
b = a.repeat(2, axis=1)
# a.shape (2, 3, 1) b.shape (2, 6, 1)
总结
numpy.repeat每次只能在一个维度上重复,而torch.repeat需要指定所有维度上的重复倍数。