torch topk函数
这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。
用法
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
input:一个tensor数据
k:指明是得到前k个数据以及其index
dim: 指定在哪个维度上排序, 默认是最后一个维度
largest:如果为True,按照大到小排序; 如果为False,按照小到大排序
sorted:返回的结果按照顺序返回
out:可缺省,不要
比如,三行两列,3个样本,2个类别。
import torch pred = torch.randn((4, 5)) print(pred) values, indices = pred.topk(1, dim=1, largest=True, sorted=True) print(indices) # 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。 _, indices_max = pred.max(dim=1, keepdim=True) print(indices_max == indices) # pred tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263], [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156], [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185], [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]]) # indices, shape为 【4,1】, tensor([[3], #【0,0】代表 第一个样本最可能属于第一类别 [2], # 【1, 0】代表第二个样本最可能属于第二类别 [4], [2]]) # indices_max等于indices tensor([[True], [True], [True], [True]])
分类:
pytorch学习
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 单线程的Redis速度为什么快?
· 展开说说关于C#中ORM框架的用法!
· Pantheons:用 TypeScript 打造主流大模型对话的一站式集成库
· SQL Server 2025 AI相关能力初探
· 为什么 退出登录 或 修改密码 无法使 token 失效
2020-03-24 matplotlib、PIL、cv2图像操作差异分析
2020-03-24 Unity3D鼠标控制摄像机“左右移动控制视角+WASD键盘控制前后左右+空格键抬升高度”脚本
2020-03-24 unity3d视角跟随鼠标左右上下转动
2020-03-24 如何利用几个面做一个天空盒子