[PyTorch] 如何判定运算维度
实际上无论是几维,方法都是一样。假设以 torch.softmax()
为例:
# 下面运行结果所使用的代码
import torch
import numpy as np
z = np.arange(1, 33).reshape((2, 2, 2, 4))
z = torch.tensor(z, dtype=torch.float32)
# 为了使各元素softmax的结果相差不至于过大,这里简单处理一下,但是与原张量的计算结果并不等价
z = z ** 0.2
torch.softmax(z, dim=0)
一. 三维
为了方便查阅,直接放个图,方法实际上和四维是一样的。
二. 四维
假设有四维tensor:(B, C, H, W),具体为(2, 2, 2, 4)
1. 当dim=0或dim=-4时
最外层 的元素进行运算,也就是最外层 中用逗号隔开的所有对应元素相运算。例如向量中的1、17进行运算,9和25运算。
2. 当dim=1或dim=-3时
同一 batch 的最外层 的所有对应元素进行运算。例如向量中的1和9进行运算,17和25运算。
3. 当dim=2或dim=-2时
同一 Height 的最外层 的所有对应元素进行运算。例如向量中的1和5进行运算,9和13进行运算,17和21进行运算,25和29进行运算。
4. 当dim=3或dim=-1时
同一 W 的最外层 的所有对应元素进行运算。例如向量中的1、2、3、4进行运算,5、6、7、8进行运算,……,29、30、31、32进行运算。
三. 高维
类比四维的情况……
【注】可能我表达确实不行,如看不懂,可看看这篇文章:https://zhuanlan.zhihu.com/p/525276061
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 使用C#创建一个MCP客户端
· ollama系列1:轻松3步本地部署deepseek,普通电脑可用
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 按钮权限的设计及实现