pytorch中的张量函数
troch.cat()& torch.stack()
.cat 和 .stack的区别在于 cat会增加现有维度的值,可以理解为续接,stack会新加增加一个维度,可以理解为叠加
x1 = torch.tensor([[11, 21, 31], [21, 31, 41]], dtype=torch.int)
x2 = torch.tensor([[12, 22, 32], [22, 32, 42]])
inputs = [x1, x2]
R0 = torch.cat(inputs, dim=0)
print("R0:\n", R0)
print("R0.shape:\n", R0.shape)
R = torch.stack((x1, x2), dim=0)
print("R:\n", R)
print("R.shape:\n", R.shape)
print(R==R0)# 会报错
输出如下:
R0:
tensor([[11, 21, 31],
[21, 31, 41],
[12, 22, 32],
[22, 32, 42]])
R0.shape:
torch.Size([4, 3])
R:
tensor([[[11, 21, 31],
[21, 31, 41]],
[[12, 22, 32],
[22, 32, 42]]])
R.shape:
torch.Size([2, 2, 3])
torch.nn.cat()和torch.nn.stack()是PyTorch中用于处理张量的两个常用函数,它们在拼接和堆叠张量时有一些区别。让我们来详细了解一下:
torch.nn.cat():
可以理解为“拼接”操作。
可以按行或按列拼接张量。
操作后得到的张量维度不会增加。
要求用来拼接的张量形状匹配(非拼接维度必须一致)。
示例代码:
import torch
import numpy as np
a = torch.from_numpy(np.arange(0, 12).reshape(3, 4))
b = torch.from_numpy(np.arange(0, 12).reshape(3, 4))
# 按行拼接
c0 = torch.cat((a, b), dim=0)
print(f'按行拼接结果:\n{c0}')
# 按列拼接
c1 = torch.cat((a, b), dim=1)
print(f'按列拼接结果:\n{c1}')
按行拼接要求两个张量的列数必须一致,按列拼接要求两个张量的行数必须一致。
torch.nn.stack():
可以理解为“堆叠”操作。
操作后得到的张量会增加一维。
用来堆叠的张量形状必须完全一致。
示例代码:
三种可能的堆叠方式,分别对应dim=0, 1, 2
s0 = torch.stack((a, b), dim=0)
s1 = torch.stack((a, b), dim=1)
s2 = torch.stack((a, b), dim=2)
print(f'按dim=0堆叠结果:\n{s0}')
print(f'按dim=1堆叠结果:\n{s1}')
print(f'按dim=2堆叠结果:\n{s2}')
二维张量有三种可能的堆叠方式,分别对应dim=0, 1, 2。
总结:
torch.nn.cat()用于拼接,不增加维度。
torch.nn.stack()用于堆叠,增加一维。1 2 3
quezze()&unquezze()
在 PyTorch 中,unsqueeze() 和 squeeze() 是两个常用的函数,用于调整张量的维度。让我详细解释一下它们的作用和用法:
torch.squeeze() 函数:
torch.squeeze(A, N) 的作用是减少数组 A 指定位置 N 的维度。
如果不指定位置参数 N,则会删除数组 A 中所有维度为 1 的维度。
举例:
如果数组 A 的维度为(1,1,3),执行 torch.squeeze(A, 1) 后,A 的维度变为(1,3),中间的维度被删除。
注意:
如果指定的维度大于 1,操作将无效。
如果不指定维度 N,将删除所有维度为 1 的维度。
torch.unsqueeze() 函数:
torch.unsqueeze(A, N) 的作用是在数组 A 的指定位置 N 增加一个维度。
例如,对于一个两行三列的数组 A,有三个位置可以增加维度,分别是:
位置 0:(2,3)
位置 1:(2,3)
位置 2:(2,3)
如果执行 torch.unsqueeze(A, 1),数据的维度将变为(2,1,3)。
代码示例:
import torch
# squeeze函数
a = torch.randn(1, 1, 3)
print("原始维度:", a.shape)
b = torch.squeeze(a)
print("去除中间维度:", b.shape)
c = torch.squeeze(a, 0)
print("去除第一个维度:", c.shape)
d = torch.squeeze(a, 1)
print("去除第二个维度:", d.shape)
e = torch.squeeze(a, 2) # 如果去掉第三维,则数不够放了,所以直接保留
print("去除第三个维度:", e.shape)
# unsqueeze函数
f = torch.randn(1, 3)
print("原始维度:", f.shape)
g = torch.unsqueeze(f, 0)
print("在第一个位置增加维度:", g.shape)
h = torch.unsqueeze(f, 1)
print("在第二个位置增加维度:", h.shape)
i = torch.unsqueeze(f, 2)
print("在第三个位置增加维度:", i.shape)
torch.permute()
重新布局