Category: note
Tag: python
2023-08-28 16:26

[note] Comparing several dimension-manipulation methods in PyTorch

prologue

While reading some code today I started wondering about the difference between torch.unbind + torch.cat and torch.reshape; at first glance, reshape seems more convenient.

chatgpt

Question

x is a 3-dimensional tensor; what is the difference between torch.cat(torch.unbind(x, dim=1), dim=0) and torch.reshape(x, (x.shape[0]*x.shape[1], -1))?

Answer

Both torch.cat(torch.unbind(x, dim=1), dim=0) and torch.reshape(x, (x.shape[0]*x.shape[1], -1)) are operations that reshape and reorganize a 3-dimensional tensor x, but they achieve this in slightly different ways.

  1. torch.cat(torch.unbind(x, dim=1), dim=0):

    • This operation involves two steps. First, it unbinds the tensor x along its second dimension (dim=1), which effectively splits the tensor into a list of tensors along that dimension. Each tensor in the list corresponds to one "column" from the original 3D tensor.
    • Then, it uses torch.cat to concatenate these split tensors along the specified dimension (dim=0), effectively stacking them on top of each other to create a new tensor with a reshaped layout.
    • The final shape of the tensor will be (x.shape[0] * x.shape[1], x.shape[2]).
  2. torch.reshape(x, (x.shape[0]*x.shape[1], -1)):

    • This operation directly reshapes the entire tensor x without splitting or concatenating it.
    • The desired shape is specified as a tuple (x.shape[0]*x.shape[1], -1), where -1 in the second dimension means that PyTorch will automatically calculate the size of that dimension to preserve the total number of elements in the tensor.
    • The final shape of the tensor will also be (x.shape[0] * x.shape[1], x.shape[2]).

In summary, both operations reshape the tensor x into a 2-dimensional tensor with the same number of elements, but they approach it differently. The first approach involves splitting and concatenating along dimensions, while the second approach directly reshapes the tensor. The end result will be the same in terms of shape and arrangement of elements.
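The shape part of this answer is easy to verify — both expressions do produce a tensor of shape (x.shape[0]*x.shape[1], x.shape[2]). A minimal check (whether the element arrangement also matches, as the answer claims, is a separate question, tested below):

```python
import torch

x = torch.randn(3, 2, 5)  # an arbitrary 3-D tensor

a = torch.cat(torch.unbind(x, dim=1), dim=0)
b = torch.reshape(x, (x.shape[0] * x.shape[1], -1))

# Both results have the same shape: (3*2, 5)
print(a.shape, b.shape)  # torch.Size([6, 5]) torch.Size([6, 5])
```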

Test

I believed ChatGPT at first, but the more I thought about it, the less right it seemed, so I ran the following experiment:

import torch
from einops import rearrange

x = torch.randint(0, 100, (3, 2, 1))
print('origin: ',x.squeeze())
# Compare four different dimension-manipulation approaches
unbind1 = torch.cat(torch.unbind(x, dim=1), dim=0)
unbind2 = torch.cat(torch.unbind(x, dim=0), dim=0)
reshape1 = torch.reshape(x, (x.shape[0]*x.shape[1], -1))
reshape2 = torch.reshape(x, (-1, x.shape[-1]))
view1 = x.view(x.shape[0]*x.shape[1], -1)
view2 = x.view(-1, x.shape[-1])
rearrange1 = rearrange(x, 'b p c -> (b p) c')
rearrange2 = rearrange(x, 'b p c -> (p b) c')

# cat+unbind along dim=1 differs from the other three approaches
print(f'{unbind1.squeeze()=}')
print(f'{unbind2.squeeze()=}')
print(f'{reshape1.squeeze()=}')
print(f'{reshape2.squeeze()=}')
print(f'{view1.squeeze()=}')
print(f'{view2.squeeze()=}')
print(f'{rearrange1.squeeze()=}')
print(f'{rearrange2.squeeze()=}')

# unlike view1, the cat+unbind result (unbind1) cannot be rearranged back to x with the (b p) pattern
x2 = rearrange(view1, '(b p) c -> b p c', b=3, p=2)
x3 = rearrange(unbind1, '(b p) c -> b p c', b=3, p=2)
print(f'x==x2: {(x==x2).squeeze()}')
print(f'x==x3: {(x==x3).squeeze()}')
rearrange1_ = rearrange(rearrange1, '(b p) c -> b p c', b=3, p=2)
rearrange2_ = rearrange(rearrange2, '(b p) c -> b p c', b=3, p=2)
print(f'{rearrange1_.squeeze()=}')
print(f'{rearrange2_.squeeze()=}')

Output:

origin:  tensor([[24, 19],
        [52, 89],
        [57, 66]])
unbind1.squeeze()=tensor([24, 52, 57, 19, 89, 66])
unbind2.squeeze()=tensor([24, 19, 52, 89, 57, 66])
reshape1.squeeze()=tensor([24, 19, 52, 89, 57, 66])
reshape2.squeeze()=tensor([24, 19, 52, 89, 57, 66])
view1.squeeze()=tensor([24, 19, 52, 89, 57, 66])
view2.squeeze()=tensor([24, 19, 52, 89, 57, 66])
rearrange1.squeeze()=tensor([24, 19, 52, 89, 57, 66])
rearrange2.squeeze()=tensor([24, 52, 57, 19, 89, 66])
x==x2: tensor([[True, True],
        [True, True],
        [True, True]])
x==x3: tensor([[ True, False],
        [False, False],
        [False,  True]])
rearrange1_.squeeze()=tensor([[24, 19],
        [52, 89],
        [57, 66]])
rearrange2_.squeeze()=tensor([[24, 52],
        [57, 19],
        [89, 66]])

epilogue

In summary, rearrange is intuitive and flexible. Suppose x = [[97, 14], [ 0, 16], [55, 62]]. Then rearrange(x, 'b p -> (p b)') and torch.cat(torch.unbind(x, dim=1), dim=0) both split x by columns (axis=1) and concatenate the pieces, giving [97, 0, 55, 14, 16, 62].

rearrange(x, 'b p -> (b p)') and torch.cat(torch.unbind(x, dim=0), dim=0), as well as reshape / view, instead split x by rows (axis=0) and concatenate, giving [97, 14, 0, 16, 55, 62].

Moreover, a rearrange result is invertible: apply the mirrored pattern, e.g. '(b p) c -> b p c' undoes 'b p c -> (b p) c'.

Author: 心有所向,日复一日,必有精进

Link: https://www.cnblogs.com/Stareven233/p/17662708.html

License: This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 2.5 China Mainland License.

Posted by NoNoe