[PyTorch] 如何判定运算维度
实际上无论是几维,方法都是一样。假设以 torch.softmax()
为例:
# 下面运行结果所使用的代码
import torch
import numpy as np
z = np.arange(1, 33).reshape((2, 2, 2, 4))
z = torch.tensor(z, dtype=torch.float32)
# 为了使各元素softmax的结果相差不至于过大,这里简单处理一下,但是与原张量的计算结果并不等价
z = z ** 0.2
torch.softmax(z, dim=0)
一. 三维
为了方便查阅,直接放个图,方法实际上和四维是一样的。
二. 四维
假设有四维tensor:(B, C, H, W),具体为(2, 2, 2, 4)
\[\begin{align}
[[[&[\space1, \quad 2, \quad 3, \quad 4], \\
&[\space5, \quad 6, \quad 7, \quad 8]], \\
[&[\space9, \space\space\space 10, \space\space 11, \space\space 12], \\
&[13, \space\space 14, \space\space 15, \space\space 16]]], \\
[&[17, \space\space 18, \space\space 19, \space\space 20], \\
&[21, \space\space 22, \space\space 23, \space\space 24]], \\
[&[25, \space\space 26, \space\space 27, \space\space 28], \\
&[29, \space\space 30, \space\space 31, \space\space 32]]]]
\end{align}
\]
1. 当dim=0或dim=-4时
最外层 \([\space]\) 的元素进行运算,也就是最外层 \([\space]\) 中用逗号隔开的所有对应元素相运算。例如向量中的1、17进行运算,9和25运算。
2. 当dim=1或dim=-3时
同一 batch 的最外层 \([\space]\) 的所有对应元素进行运算。例如向量中的1和9进行运算,17和25运算。
3. 当dim=2或dim=-2时
同一 Height 的最外层 \([\space]\) 的所有对应元素进行运算。例如向量中的1和5进行运算,9和13进行运算,17和21进行运算,25和29进行运算。
4. 当dim=3或dim=-1时
同一 W 的最外层 \([\space]\) 的所有对应元素进行运算。例如向量中的1、2、3、4进行运算,5、6、7、8进行运算,……,29、30、31、32进行运算。
三. 高维
类比四维的情况……
【注】可能我表达确实不行,如看不懂,可看看这篇文章:https://zhuanlan.zhihu.com/p/525276061