[note] Comparing several dimension-manipulation approaches in PyTorch
prologue
Today, while reading some code, I started wondering about the difference between torch.unbind + torch.cat and torch.reshape. Intuitively, reshape seems more convenient.
chatgpt
Question
x is a 3-dimensional tensor; what is the difference between torch.cat(torch.unbind(x, dim=1), dim=0) and torch.reshape(x, (x.shape[0]*x.shape[1], -1))?
Answer
Both torch.cat(torch.unbind(x, dim=1), dim=0) and torch.reshape(x, (x.shape[0]*x.shape[1], -1)) reshape and reorganize a 3-dimensional tensor x, but they achieve this in slightly different ways.
- torch.cat(torch.unbind(x, dim=1), dim=0): This operation involves two steps. First, it unbinds the tensor x along its second dimension (dim=1), which effectively splits the tensor into a list of tensors along that dimension; each tensor in the list corresponds to one "column" of the original 3D tensor. Then, it uses torch.cat to concatenate these tensors along dim=0, stacking them on top of each other to create a new tensor with a reshaped layout. The final shape of the tensor will be (x.shape[0] * x.shape[1], x.shape[2]).
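The two steps can be sketched on a small tensor (the shapes here are my own illustrative choice, not from the original post):

```python
import torch

x = torch.arange(12).reshape(2, 3, 2)   # shape (2, 3, 2)
parts = torch.unbind(x, dim=1)          # tuple of 3 tensors, each (2, 2)
assert len(parts) == 3 and parts[0].shape == (2, 2)

out = torch.cat(parts, dim=0)           # shape (2*3, 2) = (6, 2)
assert out.shape == (6, 2)
# parts[0] is x[:, 0, :], so the first two rows of `out`
# come entirely from "column" 0 of the middle dimension
assert torch.equal(out[:2], x[:, 0, :])
```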
- torch.reshape(x, (x.shape[0]*x.shape[1], -1)): This operation directly reshapes the entire tensor x without splitting or concatenating it. The desired shape is specified as the tuple (x.shape[0]*x.shape[1], -1), where -1 in the second dimension means that PyTorch will automatically infer the size of that dimension so as to preserve the total number of elements. The final shape of the tensor will also be (x.shape[0] * x.shape[1], x.shape[2]).
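A minimal sketch of the -1 inference (again with illustrative shapes of my own):

```python
import torch

x = torch.arange(12).reshape(2, 3, 2)
y = torch.reshape(x, (x.shape[0] * x.shape[1], -1))
# -1 is inferred: 12 elements / 6 rows = 2 columns
assert y.shape == (6, 2)
# reshape reads elements in row-major (contiguous) memory order,
# so the flattened element order is unchanged
assert torch.equal(y.flatten(), x.flatten())
```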
In summary, both operations reshape the tensor x into a 2-dimensional tensor with the same number of elements, but they approach it differently: the first splits along a dimension and concatenates, while the second directly reshapes the tensor. The end result will be the same in terms of shape and arrangement of elements.
Test
I initially trusted ChatGPT, but the more I thought about it the less right it seemed, so I ran the following experiment:
import torch
from einops import rearrange
x = torch.randint(0, 100, (3, 2, 1))
print('origin: ',x.squeeze())
# Compare four different dimension-manipulation approaches
unbind1 = torch.cat(torch.unbind(x, dim=1), dim=0)
unbind2 = torch.cat(torch.unbind(x, dim=0), dim=0)
reshape1 = torch.reshape(x, (x.shape[0]*x.shape[1], -1))
reshape2 = torch.reshape(x, (-1, x.shape[-1]))
view1 = x.view(x.shape[0]*x.shape[1], -1)
view2 = x.view(-1, x.shape[-1])
rearrange1 = rearrange(x, 'b p c -> (b p) c')
rearrange2 = rearrange(x, 'b p c -> (p b) c')
# The cat+unbind result differs from all three of the others
print(f'{unbind1.squeeze()=}')
print(f'{unbind2.squeeze()=}')
print(f'{reshape1.squeeze()=}')
print(f'{reshape2.squeeze()=}')
print(f'{view1.squeeze()=}')
print(f'{view2.squeeze()=}')
print(f'{rearrange1.squeeze()=}')
print(f'{rearrange2.squeeze()=}')
# The cat+unbind result (unbind1) cannot be rearranged back into x the way view1 can
x2 = rearrange(view1, '(b p) c -> b p c', b=3, p=2)
x3 = rearrange(unbind1, '(b p) c -> b p c', b=3, p=2)
print(f'x==x2: {(x==x2).squeeze()}')
print(f'x==x3: {(x==x3).squeeze()}')
rearrange1_ = rearrange(rearrange1, '(b p) c -> b p c', b=3, p=2)
rearrange2_ = rearrange(rearrange2, '(b p) c -> b p c', b=3, p=2)
print(f'{rearrange1_.squeeze()=}')
print(f'{rearrange2_.squeeze()=}')
Output:
origin: tensor([[24, 19],
[52, 89],
[57, 66]])
unbind1.squeeze()=tensor([24, 52, 57, 19, 89, 66])
unbind2.squeeze()=tensor([24, 19, 52, 89, 57, 66])
reshape1.squeeze()=tensor([24, 19, 52, 89, 57, 66])
reshape2.squeeze()=tensor([24, 19, 52, 89, 57, 66])
view1.squeeze()=tensor([24, 19, 52, 89, 57, 66])
view2.squeeze()=tensor([24, 19, 52, 89, 57, 66])
rearrange1.squeeze()=tensor([24, 19, 52, 89, 57, 66])
rearrange2.squeeze()=tensor([24, 52, 57, 19, 89, 66])
x==x2: tensor([[True, True],
[True, True],
[True, True]])
x==x3: tensor([[ True, False],
[False, False],
[False, True]])
rearrange1_.squeeze()=tensor([[24, 19],
[52, 89],
[57, 66]])
rearrange2_.squeeze()=tensor([[24, 52],
[57, 19],
[89, 66]])
epilogue
In summary, rearrange is intuitive and flexible. Suppose x = [[97, 14], [0, 16], [55, 62]]; then rearrange(x, 'b p -> (p b)') and torch.cat(torch.unbind(x, dim=1), dim=0) both split x by column (axis=1) and then concatenate, giving [97, 0, 55, 14, 16, 62], whereas rearrange(x, 'b p -> (b p)'), torch.cat(torch.unbind(x, dim=0), dim=0), and reshape / view split x by row (axis=0) and then concatenate, giving [97, 14, 0, 16, 55, 62].
Moreover, the result of rearrange is invertible using the corresponding pattern: '(b p) c -> b p c' undoes 'b p c -> (b p) c'.
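The inverse-pattern rule can also be checked without einops, since 'b p -> (p b)' is just transpose-then-flatten (a plain-torch sketch of the example above):

```python
import torch

x = torch.tensor([[97, 14], [0, 16], [55, 62]])  # shape (b=3, p=2)

# 'b p -> (p b)': transpose, then flatten in row-major order
pb = x.t().reshape(-1)
assert pb.tolist() == [97, 0, 55, 14, 16, 62]

# inverse '(p b) -> b p': unflatten to (p, b), then transpose back
x_back = pb.reshape(2, 3).t()
assert torch.equal(x_back, x)
```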
Author: 心有所向,日复一日,必有精进
Link: https://www.cnblogs.com/Stareven233/p/17662708.html
License: This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 2.5 China Mainland License.