[PyTorch] 如何判定运算维度

实际上无论是几维,方法都是一样。假设以 torch.softmax() 为例:

# 下面运行结果所使用的代码
import torch
import numpy as np
z = np.arange(1, 33).reshape((2, 2, 2, 4))
z = torch.tensor(z, dtype=torch.float32)
# 为了使各元素softmax的结果相差不至于过大,这里简单处理一下,但是与原张量的计算结果并不等价
z = z ** 0.2
torch.softmax(z, dim=0) 

一. 三维

为了方便查阅,直接放个图,方法实际上和四维是一样的。

image

二. 四维

假设有四维tensor:(B, C, H, W),具体为(2, 2, 2, 4)

\[\begin{align} [[[&[\space1, \quad 2, \quad 3, \quad 4], \\ &[\space5, \quad 6, \quad 7, \quad 8]], \\ [&[\space9, \space\space\space 10, \space\space 11, \space\space 12], \\ &[13, \space\space 14, \space\space 15, \space\space 16]]], \\ [&[17, \space\space 18, \space\space 19, \space\space 20], \\ &[21, \space\space 22, \space\space 23, \space\space 24]], \\ [&[25, \space\space 26, \space\space 27, \space\space 28], \\ &[29, \space\space 30, \space\space 31, \space\space 32]]]] \end{align} \]


1. 当dim=0或dim=-4时

最外层 \([\space]\) 的元素进行运算,也就是最外层 \([\space]\) 中用逗号隔开的所有对应元素相运算。例如向量中的117进行运算,925运算。

image


2. 当dim=1或dim=-3时

同一 batch 的最外层 \([\space]\) 的所有对应元素进行运算。例如向量中的19进行运算,1725运算。

image


3. 当dim=2或dim=-2时

同一 Height 的最外层 \([\space]\) 的所有对应元素进行运算。例如向量中的15进行运算,913进行运算,1721进行运算,2529进行运算。

image


4. 当dim=3或dim=-1时

同一 W 的最外层 \([\space]\) 的所有对应元素进行运算。例如向量中的1234进行运算,5678进行运算,……,29303132进行运算。

image

三. 高维

类比四维的情况……

【注】可能我表达确实不行,如看不懂,可看看这篇文章:https://zhuanlan.zhihu.com/p/525276061

posted @ 2023-12-05 17:32  小贼的自由  阅读(22)  评论(0编辑  收藏  举报