pytorch 带batch的tensor类型图像显示操作
pytorch 带batch的tensor类型图像显示操作
pytorch 带batch的tensor类型图像显示操作_python_脚本之家 (jb51.net)
这篇文章主要介绍了pytorch 带batch的tensor类型图像显示操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
项目场景
pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。
那么如何显示dataloader里面带batch的tensor类型的图像呢?
显示图像
绘图最常用的库就是matplotlib:
1
|
pip install matplotlib |
显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:
数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)
用法示例如下:
1
2
3
4
5
|
>>> x = torch.randn( 2 , 3 , 5 ) >>> x.size() torch.Size([ 2 , 3 , 5 ]) >>> x.permute( 1 , 2 , 0 ).size() torch.Size([ 3 , 5 , 2 ]) |
代码示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
#%% 导入模块 import torch import matplotlib.pyplot as plt from torchvision.utils import make_grid from torch.utils.data import DataLoader from torchvision import datasets, transforms #%% 下载数据集 train_file = datasets.MNIST( root = './dataset/' , train = True , transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(( 0.1307 ,), ( 0.3081 ,)) ]), download = True ) #%% 制作数据加载器 train_loader = DataLoader( dataset = train_file, batch_size = 9 , shuffle = True ) #%% 训练数据可视化 images, labels = next ( iter (train_loader)) print (images.size()) # torch.Size([9, 1, 28, 28]) plt.figure(figsize = ( 9 , 9 )) for i in range ( 9 ): plt.subplot( 3 , 3 , i + 1 ) plt.title(labels[i].item()) plt.imshow(images[i].permute( 1 , 2 , 0 ), cmap = 'gray' ) plt.axis( 'off' ) plt.show() |
这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。
所以,如果你想查看训练集的原始图像,还得反标准化。
标准化:image = (image-mean)/std
反标准化:image = image*std+mean
我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:
最终效果
补充:PIL,plt显示tensor类型的图像
该方法针对显示Dataloader读取的图像
PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
# 方法1:Image.show() # transforms.ToPILImage()中有一句 # npimg = np.transpose(pic.numpy(), (1, 2, 0)) # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维 img = transforms.ToPILImage(image[ 0 ]) img.show() # 方法2:plt.imshow(ndarray) img = image[ 0 ] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维 img = img.numpy() # FloatTensor转为ndarray img = np.transpose(img, ( 1 , 2 , 0 )) # 把channel那一维放到最后 # 显示图片 plt.imshow(img) plt.show() cnt + = 1 |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!