torch.Tensor.repeat和numpy.repeat区别

函数声明

  • numpy.repeat(a, repeats, axis=None)

numpy.repeat需要指定维度,repeats为一个数,数组在axis维度上重复repeats倍。

  • torch.Tensor.repeat(*sizes)

size是一个元组,元素个数同tensor.shape元素个数,size中元素为tensor相应维度重复倍数。

代码实战

import torch
import numpy as np

a = torch.rand(2, 3, 1)
b = a.repeat((4, 1, 1))

# a.shape (2, 3, 1)  b.shape (8, 3, 1)

a = torch.rand(2, 3, 1)
b = a.repeat((1, 2, 1))

# a.shape (2, 3, 1)  b.shape (2, 6, 1)

a = np.zeros((2, 3, 1))
b = a.repeat(4, axis=0)

# a.shape (2, 3, 1)  b.shape (8, 3, 1)

a = np.zeros((2, 3, 1))
b = a.repeat(2, axis=1)

# a.shape (2, 3, 1)  b.shape (2, 6, 1)

总结

numpy.repeat每次只能在一个维度上重复,而torch.repeat需要指定所有维度上的重复倍数。

posted @ 2022-03-23 15:13  Js2Hou  阅读(222)  评论(0编辑  收藏  举报