matplotlib展示预测图片结果,按Enter展示下一批
虽然 images,labels = next(iter(test_data)) 可以每次1batch获取,但是超级慢,不推荐。
import torch import platform if platform.system() == 'Windows': import matplotlib matplotlib.use('TkAgg') import matplotlib.pyplot as plt elif platform.system() == 'Linux': import matplotlib.pyplot as plt #.pt的加载方式 model = torch.load(r'best_001_1.000000_model.pt') device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) #模型转入GPU model.eval() #测试模式 # 查看前10batch的预测效果图 plt.figure(figsize=(12, 4)) # 创建窗口并指定大小(英寸) for index,(images,labels) in enumerate(test_data): images,labels = images.cuda(),labels.cuda() #样本转入GPU outputs = model(images) #预测1个batch _, predicted_labels = torch.max(outputs.data, 1) #显示1个batch的图像,title:预测标签-真实标签,预测错误title为红色 for i in range(batch): images_cpu = images.cpu() #GPU转到CPU来进行显示 img = images_cpu[i].squeeze(0).permute(1, 2, 0) #NCHW变为CHW再变为HWC img = (img+1)/2 #[-1,1]变为[0,1]。[0,1]或[0,255]都可以显示 plt.subplot(1,4,i+1) #1行4列来排布 plt.title('{}-{}'.format(classes[predicted_labels[i].item()], classes[labels[i].item()]), color='red' if classes[predicted_labels[i].item()]!=classes[labels[i].item()] else 'black')#1个batch里的图对应的标签为标题 plt.imshow(img) #显示图像 plt.axis('off') #不显示坐标轴 # plt.show() #显示窗体并阻塞 plt.ion() #显示窗口,不阻塞 plt.pause(0.1) #阻塞0.1s,其实是让窗口一直驻留显示 input('Press Enter to Continue {}'.format(index+1)) if index==9: plt.close() break