pytorch中dim的理解

首先先说个结论,就是dim等于那个维度,就是把那个维度给消除了,比如说shape=(2,3,4),如果dim=0,最后的结果的shape=(3,4),如果dim=1,最后的结果的shape=(2,4),如果dim=2的话,最后的结果的shape=(2,3)

首先我们看个例子吧:

import numpy as np
import torch
x = torch.tensor([
        [1,2,3],
        [4,5,6]
    ])
# 我们可以看到"行"是dim=0, "列"是dim=1
print(x.shape)
'''输出
torch.Size([2, 3])
'''
# 但是我们按照dim=0求和, 是按照列相加,最后的shape=(3)
torch.sum(x, dim=0)
'''输出
tensor([5, 7, 9])
'''
# 但是我们按照dim=1求和, 是按照行相加1,最后的shape=(2)
torch.sum(x, dim=1)
'''输出
tensor([ 6, 15])
'''

下面也是类似:

# 看一下三维的
x = torch.tensor([
        [
         [1,2,3],
         [4,5,6]
        ],
        [
         [1,2,3],
         [4,5,6]
        ],
        [
         [1,2,3],
         [4,5,6]
        ]
    ])
# 我们可以看到第三维是dim=0, "行"是dim=1, 列是dim=2
print(x.shape)
'''输出
torch.Size([3, 2, 3])
'''

image

然后对于softmax:

import torch
import numpy as np
import torch.nn.functional as F
 
data = np.array([[0.1, 0.3, 0.6], [1.5,2.1 ,0.55]])
t_data = torch.from_numpy(data)
print(t_data)
print(t_data.shape)
#print(t_data.type())
print("**************************************************************************")
prob = F.softmax(t_data,dim=0) # dim = 0,在列上进行Softmax;dim=1,在行上进行Softmax
print(prob)
print(prob.shape)
#print(prob.type())
print("**************************************************************************")
prob = F.softmax(t_data,dim=1) # dim = 0,在列上进行Softmax;dim=1,在行上进行Softmax
print(prob)
print(prob.shape)
#print(prob.type())

image
当dim=0时候按着列进行softmax(0.1978+0.8022 = 1),当dim=1的时候按着行精选softmax(0.2584+0.3165+0.4260 = 1)

现在是不是有所感悟,你细品

总结:size不等于1,dim 指定沿着某一维度挤压,例如dim=0,就是行被压缩,按着列操作。如dim=1,就是列被压缩,按着行操作。

然后还有就是dim等于那个维度,就是把那个维度给消除了,比如说shape=(2,3,4),如果dim=0,最后的结果的shape=(3,4),如果dim=1,最后的结果的shape=(2,4),如果dim=2的话,最后的结果的shape=(2,3)

posted @   lipu123  阅读(463)  评论(0编辑  收藏  举报
(评论功能已被禁用)
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示