pytorch中的张量函数

troch.cat()& torch.stack()

.cat 和 .stack的区别在于 cat会增加现有维度的值,可以理解为续接,stack会新加增加一个维度,可以理解为叠加

x1 = torch.tensor([[11, 21, 31], [21, 31, 41]], dtype=torch.int)
x2 = torch.tensor([[12, 22, 32], [22, 32, 42]])
inputs = [x1, x2]
R0 = torch.cat(inputs, dim=0)
print("R0:\n", R0)
print("R0.shape:\n", R0.shape)
R = torch.stack((x1, x2), dim=0)
print("R:\n", R)
print("R.shape:\n", R.shape)
print(R==R0)# 会报错

输出如下:

R0:
 tensor([[11, 21, 31],
        [21, 31, 41],
        [12, 22, 32],
        [22, 32, 42]])
R0.shape:
 torch.Size([4, 3])
R:
 tensor([[[11, 21, 31],
         [21, 31, 41]],

        [[12, 22, 32],
         [22, 32, 42]]])
R.shape:
 torch.Size([2, 2, 3])

torch.nn.cat()和torch.nn.stack()是PyTorch中用于处理张量的两个常用函数,它们在拼接和堆叠张量时有一些区别。让我们来详细了解一下:

torch.nn.cat():
可以理解为“拼接”操作。
可以按行或按列拼接张量。
操作后得到的张量维度不会增加。
要求用来拼接的张量形状匹配(非拼接维度必须一致)。
示例代码:

import torch
import numpy as np

a = torch.from_numpy(np.arange(0, 12).reshape(3, 4))
b = torch.from_numpy(np.arange(0, 12).reshape(3, 4))

# 按行拼接
c0 = torch.cat((a, b), dim=0)
print(f'按行拼接结果:\n{c0}')

# 按列拼接
c1 = torch.cat((a, b), dim=1)
print(f'按列拼接结果:\n{c1}')

按行拼接要求两个张量的列数必须一致,按列拼接要求两个张量的行数必须一致。
torch.nn.stack():
可以理解为“堆叠”操作。
操作后得到的张量会增加一维。
用来堆叠的张量形状必须完全一致。
示例代码:
三种可能的堆叠方式,分别对应dim=0, 1, 2


s0 = torch.stack((a, b), dim=0)
s1 = torch.stack((a, b), dim=1)
s2 = torch.stack((a, b), dim=2)

print(f'按dim=0堆叠结果:\n{s0}')
print(f'按dim=1堆叠结果:\n{s1}')
print(f'按dim=2堆叠结果:\n{s2}')

二维张量有三种可能的堆叠方式,分别对应dim=0, 1, 2。
总结:

torch.nn.cat()用于拼接,不增加维度。
torch.nn.stack()用于堆叠,增加一维。1 2 3

quezze()&unquezze()

在 PyTorch 中,unsqueeze() 和 squeeze() 是两个常用的函数,用于调整张量的维度。让我详细解释一下它们的作用和用法:

torch.squeeze() 函数:
torch.squeeze(A, N) 的作用是减少数组 A 指定位置 N 的维度。
如果不指定位置参数 N,则会删除数组 A 中所有维度为 1 的维度。
举例:
如果数组 A 的维度为(1,1,3),执行 torch.squeeze(A, 1) 后,A 的维度变为(1,3),中间的维度被删除。
注意:
如果指定的维度大于 1,操作将无效。
如果不指定维度 N,将删除所有维度为 1 的维度。
torch.unsqueeze() 函数:
torch.unsqueeze(A, N) 的作用是在数组 A 的指定位置 N 增加一个维度。
例如,对于一个两行三列的数组 A,有三个位置可以增加维度,分别是:
位置 0:(2,3)
位置 1:(2,3)
位置 2:(2,3)
如果执行 torch.unsqueeze(A, 1),数据的维度将变为(2,1,3)。
代码示例:

import torch

# squeeze函数
a = torch.randn(1, 1, 3)
print("原始维度:", a.shape)
b = torch.squeeze(a)
print("去除中间维度:", b.shape)
c = torch.squeeze(a, 0)
print("去除第一个维度:", c.shape)
d = torch.squeeze(a, 1)
print("去除第二个维度:", d.shape)
e = torch.squeeze(a, 2)  # 如果去掉第三维,则数不够放了,所以直接保留
print("去除第三个维度:", e.shape)

# unsqueeze函数
f = torch.randn(1, 3)
print("原始维度:", f.shape)
g = torch.unsqueeze(f, 0)
print("在第一个位置增加维度:", g.shape)
h = torch.unsqueeze(f, 1)
print("在第二个位置增加维度:", h.shape)
i = torch.unsqueeze(f, 2)
print("在第三个位置增加维度:", i.shape)

torch.permute()

重新布局

Reference

posted @ 2024-03-26 11:41  光辉233  阅读(5)  评论(0编辑  收藏  举报