6 Tensor.squeeze & Tensor.unsqueeze
Tensor.squeeze & Tensor.unsqueeze
1 Tensor的维度
张量的的定义:一个 n 维的张量就是一维数组中的所有元素都是 n - 1 维的张量。
举例说明:
import torch
a = torch.rand(3)
b = torch.rand(2, 3)
c = torch.rand(2, 2, 3)
d = torch.rand(2, 2, 2, 3)
print('长度为3的一维张量a:\n', a,
'\n\n2个3列的二维张量b:\n', b,
'\n\n2个2行,3列的三维张量c:\n', c,
'\n\n2个 2个2行3列的三维张量的 四维张量d:\n', d)
Output:
长度为3的一维张量a:
tensor([0.3920, 0.3974, 0.0011])
2个3列的二维张量b:
tensor([[0.2963, 0.4094, 0.2361],
[0.3112, 0.7413, 0.3581]])
2个2行,3列的三维张量c:
tensor([[[0.0186, 0.6931, 0.7656],
[0.2139, 0.4942, 0.5294]],
[[0.2662, 0.5021, 0.5917],
[0.2941, 0.1025, 0.5816]]])
2个 2个2行3列的三维张量的 四维张量d:
tensor([[[[0.3872, 0.3498, 0.7085],
[0.0278, 0.6373, 0.4800]],
[[0.2813, 0.4828, 0.4292],
[0.4180, 0.2158, 0.4107]]],
[[[0.9930, 0.9733, 0.6893],
[0.3464, 0.2090, 0.6614]],
[[0.6189, 0.5298, 0.2926],
[0.4828, 0.1475, 0.6485]]]])
即 d 中有 2 个 c,c 中有 2 个 b,b 中有 2 个 a。
2 tensor.squeeze()
作用:降维。 起因:加速运算
torch.squeeze(input,
dim=None, # dim 从0算起,将要挤压的维度【必填】
out=None)
如果dim指定的维度的值为1,则将该维度删除,若指定的维度值不为1,则返回原来的tensor
举例:
import torch
x = torch.rand(2,1,3)
print(x, x.shape)
print(x.squeeze(1), x.squeeze(1).shape)
print(x.squeeze(2), x.squeeze(2).shape)
Output:
tensor([[[0.5902, 0.5582, 0.1262]],
[[0.0488, 0.0957, 0.3213]]]) torch.Size([2, 1, 3])
tensor([[0.5902, 0.5582, 0.1262],
[0.0488, 0.0957, 0.3213]]) torch.Size([2, 3])
tensor([[[0.5902, 0.5582, 0.1262]],
[[0.0488, 0.0957, 0.3213]]]) torch.Size([2, 1, 3])
x.shape=[2, 1, 3] , 第一维度的值为1, 因此 x.squeeze(dim=1) 的输出会将第一维度去掉,其输出 shape=[2,3], 第二维度值不为1,因此 x.squeeze(dim=2) 输出tensor的shape不变。
3 tensor.unsqueeze()
作用:升维。 起因:适配矩阵运算
torch.unsqueeze(
input,
dim, # dim 从0算起,将要扩增的维度【必填】
out=None)
如果dim为负,则将会被转化dim+input.dim()+1,即在最后增一维
另,unsqueeze_ 和 unsqueeze 实现一样的功能, 区别在于 unsqueeze_ 是 in_place 操作,即 unsqueeze 不会对使用 unsqueeze 的 tensor 进行改变, 想要获取 unsqueeze 后的值必须赋予个新值, unsqueeze_ 则会对自己改变。
简单来说,unsqueeze_ 是 unsqueeze在个体上的函数实现。
举例:
import torch
x = torch.rand(2,3)
print(x, "x.shape:", x.shape)
y = torch.unsqueeze(x, 1)
print(y, "y.shape:", y.shape)
z = x.unsqueeze_(2)
print(z, "z.shape:", z.shape)
Output:
tensor([[0.8945, 0.3175, 0.4764],
[0.6368, 0.8772, 0.4863]]) x.shape: torch.Size([2, 3])
tensor([[[0.8945, 0.3175, 0.4764]],
[[0.6368, 0.8772, 0.4863]]]) y.shape: torch.Size([2, 1, 3])
tensor([[[0.8945],
[0.3175],
[0.4764]],
[[0.6368],
[0.8772],
[0.4863]]]) z.shape: torch.Size([2, 3, 1])