matplotlib展示预测图片结果,按Enter展示下一批

虽然  images,labels = next(iter(test_data))  可以每次1batch获取,但是超级慢,不推荐。

import torch
import platform
if platform.system() == 'Windows':
    import matplotlib
    matplotlib.use('TkAgg')
    import matplotlib.pyplot as plt
elif platform.system() == 'Linux':
    import matplotlib.pyplot as plt
   
#.pt的加载方式
model = torch.load(r'best_001_1.000000_model.pt')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)    #模型转入GPU
model.eval()        #测试模式

# 查看前10batch的预测效果图
plt.figure(figsize=(12, 4))  # 创建窗口并指定大小(英寸)
for index,(images,labels) in enumerate(test_data):
    images,labels = images.cuda(),labels.cuda()  #样本转入GPU
    outputs = model(images) #预测1个batch
    _, predicted_labels = torch.max(outputs.data, 1)

    #显示1个batch的图像,title:预测标签-真实标签,预测错误title为红色
    for i in range(batch):
        images_cpu = images.cpu() #GPU转到CPU来进行显示
        img = images_cpu[i].squeeze(0).permute(1, 2, 0) #NCHW变为CHW再变为HWC
        img = (img+1)/2 #[-1,1]变为[0,1]。[0,1]或[0,255]都可以显示
        plt.subplot(1,4,i+1) #1行4列来排布
        plt.title('{}-{}'.format(classes[predicted_labels[i].item()], classes[labels[i].item()]), color='red' if classes[predicted_labels[i].item()]!=classes[labels[i].item()] else 'black')#1个batch里的图对应的标签为标题
        plt.imshow(img) #显示图像
        plt.axis('off') #不显示坐标轴
    # plt.show() #显示窗体并阻塞
    plt.ion() #显示窗口,不阻塞
    plt.pause(0.1) #阻塞0.1s,其实是让窗口一直驻留显示
    input('Press Enter to Continue {}'.format(index+1))
    if index==9:
        plt.close()
        break

 

posted @ 2024-11-01 16:16  夕西行  阅读(6)  评论(0编辑  收藏  举报