Pytorch 数据类型、张量生成、张量操作
1. 构建数据
1.1 torch.Tensor
数据
1.1.1 torch.Tensor
常用数据类型
pytorch
的基本数据结构为 torch.Tensor
与 numpy
中 numpy.ndarray
数据结构类似,
注意:Tensor
(大写T)表示张量对象,其初始化函数为 torch.tensor()
(小写t)
-
torch.Tensor
常用数据类型torch.float64
ortorch.double
torch.float32
ortorch.float
torch.float16
torch.int64
ortorch.long
torch.int32
ortorch.int
torch.int16
torch.int8
torch.uint8
torch.bool
1.1.2 torch.Tensor
对象的常用属性或方法
数据类型
-
Tensor.dtype
:Tensor
数据类型 -
Tensor.to()
:设置 device;或修改数据类型
复制
Tensor.clone()
:复制Tensor
维度信息
-
Tensor.ndim
orTensor.dim()
:Tensor
的维度 -
Tensor.size
orTensor.size(dim=None)
:Tensor
的尺寸。dim
用于指定特定的维度 -
Tensor.numel
:Tensor
的元素个数(number of elements)
维度、尺寸变换
-
Tensor.view()
orTensor.reshape()
:改变数据的维度和尺寸。 -
Tensor.flatten(start_dim=0, end_dim=-1))
:
不同数据结构之间转换:
-
Tensor.item()
:返回Tensor
对应python
数据类型,对只有一个元素的Tensor
使用 -
Tensor.tolist()
:将Tensor
转换为列表类型(nested list
) -
Tensor.numpy()
:Tensor
转numpy.ndarray
数据类型
自动梯度相关:
-
Tensor.detach()
:返回一个新的 tensor,将其从计算图分离 -
Tensor.data
:返回的 tensor 数值,与原 tensor 共用内存 -
Tensor.grad
:返回的 tensor 梯度数值,与原 tensor 共用内存 -
Tensor.requires_grad_(requires_grad=True)
:改变 Tensor 的requires_grad
属性,是否计算梯度 -
Tensor.requires_grad
:用于指示 Tensor 是否需要计算梯度
设备 device
-
Tensor.to()
:设置 device;或修改数据类型 -
Tensor.device
:返回Tensor
所在的设备(GPU 或 CPU)- 返回值:
'cpu'
或'cuda'
- 返回值:
1.1.3 从 numpy.ndarray
创建 torch.Tensor
torch.from_numpy()
,返回的Tensor
与ndarray
共用内存
1.2 生成随机数
2 torch.Tensor
的常用操作
2.1 计算
torch.bmm(input, mat2)
:批量矩阵乘法(batch matrix-matrix produc),计算原理为:如果 input
为 \((b \times n \times m)\) 的 tensor,mat2
为 \((b \times m \times p)\) tensor, 则 out
为 \((b \times n \times p)\) 的 tensor.
-
参数:
input
和mat2
: 3-D tensors,第1个维度长度必须相等. -
返回:
out
:
2.2 索引、切片
Tensor
的索引切片方式和 numpy.ndarray
几乎是一样的。切片时支持缺省参数和省略号。
常用不规则切片提取:
根据索引修改 Tensor
元素
-
torch.where
-
torch.index_fill
-
torch.masked_fill
2.3 维度变换
-
改变尺寸、维度:
torch.reshape()
-
消除维度:
torch.squeeze(input, dim=None)
如果
Tensor
在某个维度上(通过dim
参数指定)只有一个元素,用此方法可以这个维度。 -
展平维度:
torch.flatten(input, start_dim=0, end_dim=-1)
-
增加一个维度:
torch.unsqueeze(input, dim)
2.4 合并分割
torch.cat(tensors, dim=0)
:连接,不会改变 Tensor
维度
- 参数:
tensors
:tensor
序列,list 或 tuple
torch.stack(tensors, dim=0)
:堆叠,会改变 Tensor
维度
torch.split(tensor, split_size_or_sections, dim=0)
:分割,torch.cat()
逆运算