[PyTorch] 如何判定运算维度

实际上无论是几维,方法都是一样。假设以 torch.softmax() 为例:

# 下面运行结果所使用的代码 import torch import numpy as np z = np.arange(1, 33).reshape((2, 2, 2, 4)) z = torch.tensor(z, dtype=torch.float32) # 为了使各元素softmax的结果相差不至于过大,这里简单处理一下,但是与原张量的计算结果并不等价 z = z ** 0.2 torch.softmax(z, dim=0)

一. 三维

为了方便查阅,直接放个图,方法实际上和四维是一样的。

image

二. 四维

假设有四维tensor:(B, C, H, W),具体为(2, 2, 2, 4)

(1)[[[[ 1,2,3,4],(2)[ 5,6,7,8]],(3)[[ 9,   10,  11,  12],(4)[13,  14,  15,  16]]],(5)[[17,  18,  19,  20],(6)[21,  22,  23,  24]],(7)[[25,  26,  27,  28],(8)[29,  30,  31,  32]]]]


1. 当dim=0或dim=-4时

最外层 [ ] 的元素进行运算,也就是最外层 [ ] 中用逗号隔开的所有对应元素相运算。例如向量中的117进行运算,925运算。

image


2. 当dim=1或dim=-3时

同一 batch 的最外层 [ ] 的所有对应元素进行运算。例如向量中的19进行运算,1725运算。

image


3. 当dim=2或dim=-2时

同一 Height 的最外层 [ ] 的所有对应元素进行运算。例如向量中的15进行运算,913进行运算,1721进行运算,2529进行运算。

image


4. 当dim=3或dim=-1时

同一 W 的最外层 [ ] 的所有对应元素进行运算。例如向量中的1234进行运算,5678进行运算,……,29303132进行运算。

image

三. 高维

类比四维的情况……

【注】可能我表达确实不行,如看不懂,可看看这篇文章:https://zhuanlan.zhihu.com/p/525276061

posted @   小贼的自由  阅读(35)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 使用C#创建一个MCP客户端
· ollama系列1:轻松3步本地部署deepseek,普通电脑可用
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 按钮权限的设计及实现
点击右上角即可分享
微信分享提示