torch.nn.functional中softmax的作用及其参数说明

 参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/functional/#_1

class torch.nn.Softmax(input, dim)

或:

torch.nn.functional.softmax(input, dim)

 

对n维输入张量运用Softmax函数,将张量的每个元素缩放到(0,1)区间且和为1。Softmax函数定义如下:

参数:

  dim:指明维度,dim=0表示按列计算;dim=1表示按行计算。默认dim的方法已经弃用了,最好声明dim,否则会警告:

UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.

shape:

  • 输入:(N, L)
  • 输出:(N, L)

返回结果是一个与输入维度dim相同的张量,每个元素的取值范围在(0,1)区间。

例子:

复制代码
import torch

from torch import nn
from torch import autograd

m = nn.Softmax()
input = autograd.Variable(torch.randn(2, 3))
print(input)
print(m(input))
复制代码

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py 
tensor([[ 0.2854,  0.1708,  0.4308],
        [-0.1983,  2.0705,  0.1549]])
test.py:9: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
  print(m(input))
tensor([[0.3281, 0.2926, 0.3794],
        [0.0827, 0.7996, 0.1177]])

可见默认按行计算,即dim=1

 

更明显的例子:

复制代码
import torch

import torch.nn.functional as F

x= torch.Tensor( [ [1,2,3,4],[1,2,3,4],[1,2,3,4]])

y1= F.softmax(x, dim = 0) #对每一列进行softmax
print(y1)

y2 = F.softmax(x,dim =1) #对每一行进行softmax
print(y2)

x1 = torch.Tensor([1,2,3,4])
print(x1)

y3 = F.softmax(x1,dim=0) #一维时使用dim=0,使用dim=1报错
print(y3)
复制代码

返回:

复制代码
(deeplearning) userdeMBP:pytorch user$ python test.py 
tensor([[0.3333, 0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333, 0.3333]])
tensor([[0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439]])
tensor([1., 2., 3., 4.])
tensor([0.0321, 0.0871, 0.2369, 0.6439])
复制代码

因为列的值相同,所以按列计算时每一个所占的比重都是0.3333;行都是[1,2,3,4],所以按行计算,比重结果都为[0.0321, 0.0871, 0.2369, 0.6439]

一维使用dim=1报错:

RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

 

posted @   慢行厚积  阅读(111057)  评论(0编辑  收藏  举报
编辑推荐:
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
阅读排行:
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· Docker 太简单,K8s 太复杂?w7panel 让容器管理更轻松!
点击右上角即可分享
微信分享提示