(predicted == labels).sum().item()作用

 ⚠️(predicted == labels).sum().item()作用,举个小例子介绍:

复制代码
# -*- coding: utf-8 -*-
import torch import numpy as np data1 = np.array([ [1,2,3], [2,3,4] ]) data1_torch = torch.from_numpy(data1) data2 = np.array([ [1,2,3], [2,3,4] ]) data2_torch = torch.from_numpy(data2) p = (data1_torch == data2_torch) #对比后相同的值会为1,不同则会为0 print p print type(p) d1 = p.sum() #将所有的值相加,得到的仍是tensor类别的int值 print d1 print type(d1) d2 = d1.item() #转成python数字 print d2 print type(d2)
复制代码

返回:

(deeplearning2) userdeMBP:pytorch user$ python test.py
tensor([[1, 1, 1],
        [1, 1, 1]], dtype=torch.uint8)
<class 'torch.Tensor'>
tensor(6)
<class 'torch.Tensor'>
6
<type 'int'>

 

即如果有不同的话,会变成:

复制代码
# -*- coding: utf-8 -*-
import torch import numpy
as np data1 = np.array([ [1,2,3], [2,3,4] ]) data1_torch = torch.from_numpy(data1) data2 = np.array([ [1,2,3], [4,5,6] ]) data2_torch = torch.from_numpy(data2) p = (data1_torch == data2_torch) print p print type(p) d1 = p.sum() print d1 print type(d1) d2 = d1.item() print d2 print type(d2)
复制代码

返回:

(deeplearning2) userdeMBP:pytorch user$ python test.py
tensor([[1, 1, 1],
        [0, 0, 0]], dtype=torch.uint8)
<class 'torch.Tensor'>
tensor(3)
<class 'torch.Tensor'>
3
<type 'int'>

 

posted @   慢行厚积  阅读(8120)  评论(2编辑  收藏  举报
编辑推荐:
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
阅读排行:
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· Docker 太简单,K8s 太复杂?w7panel 让容器管理更轻松!
点击右上角即可分享
微信分享提示