torch.max()函数
1:torch.max(input, dim)
函数定义:torch.max(input, dim, max=None, max_indices=None, keepdim=False) -> (Tensor, LongTensor)
作用:找出给定tensor的指定维度dim上的上的最大值,并返回最大值在该维度上的值和位置索引。
输入
input是softmax函数输出的一个tensor
dim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值
输出
函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。
a=torch.randn(3,4)
print(a)
print(a.shape)
b=torch.max(a,1)
print(b)
print(b.indices)
>tensor([[ 0.0092, -0.6736, -1.1466, -2.2001],
[-0.2323, -0.3589, 1.4158, -0.1154],
[ 0.7965, -1.3123, -2.2986, -0.8566]])
torch.Size([3, 4])
torch.return_types.max(
values=tensor([0.0092, 1.4158, 0.7965]),
indices=tensor([0, 2, 0]))
tensor([0, 2, 0])
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】