pytorch中dim的理解
首先先说个结论,就是dim等于那个维度,就是把那个维度给消除了,比如说shape=(2,3,4),如果dim=0,最后的结果的shape=(3,4),如果dim=1,最后的结果的shape=(2,4),如果dim=2的话,最后的结果的shape=(2,3)
首先我们看个例子吧:
import numpy as np
import torch
x = torch.tensor([
[1,2,3],
[4,5,6]
])
# 我们可以看到"行"是dim=0, "列"是dim=1
print(x.shape)
'''输出
torch.Size([2, 3])
'''
# 但是我们按照dim=0求和, 是按照列相加,最后的shape=(3)
torch.sum(x, dim=0)
'''输出
tensor([5, 7, 9])
'''
# 但是我们按照dim=1求和, 是按照行相加1,最后的shape=(2)
torch.sum(x, dim=1)
'''输出
tensor([ 6, 15])
'''
下面也是类似:
# 看一下三维的
x = torch.tensor([
[
[1,2,3],
[4,5,6]
],
[
[1,2,3],
[4,5,6]
],
[
[1,2,3],
[4,5,6]
]
])
# 我们可以看到第三维是dim=0, "行"是dim=1, 列是dim=2
print(x.shape)
'''输出
torch.Size([3, 2, 3])
'''
然后对于softmax:
import torch
import numpy as np
import torch.nn.functional as F
data = np.array([[0.1, 0.3, 0.6], [1.5,2.1 ,0.55]])
t_data = torch.from_numpy(data)
print(t_data)
print(t_data.shape)
#print(t_data.type())
print("**************************************************************************")
prob = F.softmax(t_data,dim=0) # dim = 0,在列上进行Softmax;dim=1,在行上进行Softmax
print(prob)
print(prob.shape)
#print(prob.type())
print("**************************************************************************")
prob = F.softmax(t_data,dim=1) # dim = 0,在列上进行Softmax;dim=1,在行上进行Softmax
print(prob)
print(prob.shape)
#print(prob.type())
当dim=0时候按着列进行softmax(0.1978+0.8022 = 1),当dim=1的时候按着行精选softmax(0.2584+0.3165+0.4260 = 1)
现在是不是有所感悟,你细品
总结:size不等于1,dim 指定沿着某一维度挤压,例如dim=0,就是行被压缩,按着列操作。如dim=1,就是列被压缩,按着行操作。
然后还有就是dim等于那个维度,就是把那个维度给消除了,比如说shape=(2,3,4),如果dim=0,最后的结果的shape=(3,4),如果dim=1,最后的结果的shape=(2,4),如果dim=2的话,最后的结果的shape=(2,3)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律