PyTorch Common Functions Cheat Sheet
# basic operation
x: torch.Tensor
x.shape -> torch.Size
x.ndim -> int # number of axes
x.T
x.numel() -> int # total size
x.reshape(*shape) -> Tensor
x.sum(), x.mean()
x.sum(axis=int) # sum along an axis; reduces the number of axes by 1
x.sum(axis=int, keepdims=True) # sum along an axis; keeps that axis with length 1
# x.sum() == x.sum(axis=[0, 1])
# x.mean() == x.sum()/x.numel()
x.cumsum(axis=int) # cumulative sum of the elements of x along an axis; does not reduce the number of dimensions; the axis argument is required
x[slice] # same as numpy
len(x) == x.shape[0]
x.clone()
torch.norm(x) -> Tensor0D # L2 norm; works for vectors and for matrices (Frobenius norm)
torch.abs(x).sum() # L1 norm
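A minimal sketch of the reductions and norms above, on a 2x3 matrix:

```python
import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)

assert x.sum().item() == 15.0                # sum over all elements
s0 = x.sum(axis=0)                           # shape (3,): axis 0 is reduced away
s0k = x.sum(axis=0, keepdims=True)           # shape (1, 3): axis kept with length 1
assert s0.shape == torch.Size([3])
assert s0k.shape == torch.Size([1, 3])

# mean == sum / numel
assert x.mean().item() == x.sum().item() / x.numel()

# cumsum keeps the original shape
assert x.cumsum(axis=0).shape == x.shape

# L2 (Frobenius) norm vs. L1 norm; L2 never exceeds L1
l2 = torch.norm(x)           # sqrt of the sum of squares
l1 = torch.abs(x).sum()      # sum of absolute values
assert l2.item() <= l1.item()
```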

# constructor
torch.arange(int) -> Tensor
torch.zeros(tuple) -> Tensor
torch.ones(tuple) -> Tensor
torch.randn(*shape) -> Tensor # each element sampled from a standard Gaussian (normal) distribution with mean 0 and standard deviation 1
torch.tensor([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]]) # from list
torch.tensor(numpy.ndarray) <=> x.numpy()
torch.zeros_like(Tensor) -> Tensor
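A quick sketch of the constructors above, including the NumPy round trip:

```python
import torch

a = torch.arange(4)                     # tensor([0, 1, 2, 3])
z = torch.zeros((2, 3))
o = torch.ones((2, 3))
r = torch.randn(2, 3)                   # standard normal samples
t = torch.tensor([[2, 1], [4, 3]])      # from a nested list

# round trip through NumPy
n = t.numpy()
back = torch.tensor(n)
assert (back == t).all()

# zeros_like copies shape and dtype
zl = torch.zeros_like(r)
assert zl.shape == r.shape and zl.sum().item() == 0.0
```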

# arithmetic operations
# Y += X is preferable to Y = Y + X (in-place; avoids allocating a new tensor)
a: torch.Tensor
b: torch.Tensor
a+b, a-b, a*b, a/b, a**b
# a*b: Hadamard product, element-wise multiplication
a==b, a<b, a>b
torch.exp(a)
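A small sketch of the element-wise operators above, plus a check that `+=` really updates in place:

```python
import torch

a = torch.tensor([1.0, 2.0, 4.0])
b = torch.tensor([1.0, 2.0, 2.0])

assert (a * b).tolist() == [1.0, 4.0, 8.0]    # Hadamard product
assert (a / b).tolist() == [1.0, 1.0, 2.0]
assert (a ** b).tolist() == [1.0, 4.0, 16.0]
assert (a == b).tolist() == [True, True, False]

# In-place update keeps the same Python object (no fresh allocation)
before = id(a)
a += b
assert id(a) == before
assert a.tolist() == [2.0, 4.0, 6.0]
```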

# matrix operation
a: torch.Tensor
b: torch.Tensor
torch.cat(tuple[Tensor], dim=int) # concatenate the tensors along dimension dim
torch.dot(Tensor1D, Tensor1D) -> Tensor0D # only supports 1D tensors; outputs a 0D tensor
torch.mv(Tensor2D, Tensor1D) -> Tensor1D
torch.mm(Tensor2D, Tensor2D) -> Tensor2D
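A minimal sketch of the shape contracts above:

```python
import torch

A = torch.arange(6, dtype=torch.float32).reshape(2, 3)
B = torch.ones(2, 3)
x = torch.ones(3)

# cat along dim 0 stacks rows; along dim 1 stacks columns
assert torch.cat((A, B), dim=0).shape == torch.Size([4, 3])
assert torch.cat((A, B), dim=1).shape == torch.Size([2, 6])

# dot: 1D x 1D -> 0D
v = torch.tensor([1.0, 2.0, 3.0])
assert torch.dot(v, x).item() == 6.0

# mv: 2D x 1D -> 1D; mm: 2D x 2D -> 2D
assert torch.mv(A, x).tolist() == [3.0, 12.0]
assert torch.mm(A, B.T).shape == torch.Size([2, 2])
```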

Broadcasting, in detail
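A minimal illustration of broadcasting: when shapes differ, axes of length 1 are stretched to match the other operand before the element-wise operation is applied.

```python
import torch

a = torch.arange(3).reshape(3, 1)   # shape (3, 1)
b = torch.arange(2).reshape(1, 2)   # shape (1, 2)

# a is broadcast along columns, b along rows, giving shape (3, 2)
c = a + b
assert c.shape == torch.Size([3, 2])
assert c.tolist() == [[0, 1], [1, 2], [2, 3]]
```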

posted @ 2021-12-01 22:42  Terrasse