理解Pytorch的dim

# Time : 2022.07.06 上午 10:33
# Author : Vandaci(cnfendaki@qq.com)
# File : learning_tensor_dim.py
# Project : LearningPytorch
import torch
import torch.nn as nn
if __name__ == '__main__':
a = torch.tensor([[1., 2.], [3., 4.]])
# dim=2 shape=[2,2] (row,col) dim就是沿着row和col的方向,
# 沿着row的方向就是每列,沿着col的方向就是每行
smax = nn.Softmax(dim=0)
o = smax(a)
print(o)
'''输出结果
tensor([[0.1192, 0.1192],
[0.8808, 0.8808]])
'''
smax = nn.Softmax(dim=1)
o = smax(a)
print(o)
'''输出结果
tensor([[0.2689, 0.7311],
[0.2689, 0.7311]])
'''
# 对于dim>=3,亦可照此推断
b = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8.]]])
smax = nn.Softmax(dim=0) # 沿着batch的方向:1对应5
o = smax(b)
print(o)
'''输出结果
tensor([[[0.0180, 0.0180],
[0.0180, 0.0180]],
[[0.9820, 0.9820],
[0.9820, 0.9820]]])
'''
pass
posted @   Vandaci  阅读(66)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
点击右上角即可分享
微信分享提示