pytorch读入图片并显示np.transpose(np_image, [1, 2, 0])
np.transpose(np_image, [1, 2, 0])
pytorch中读入图片并进行显示时
# visualization of an example of training data def show_image(tensor_image): np_image = tensor_image.numpy() np_image = np.transpose(np_image, [1, 2, 0])*0.5 + 0.5 # 转置后做逆归一化 plt.imshow(np_image)
plt.show() X = iter(train_loader).next()[0] print(X.size()) show_image(X)
其中有一行命令用来转置
np.transpose(np_image, [1, 2, 0])
主要是Pytorch中使用的数据格式与plt.imshow()函数的格式不一致
Pytorch中为[Channels, H, W]
而plt.imshow()中则是[H, W, Channels]
因此,要先转置一下。
该函数的解释见:plt.imshow()
pytorch读入并显示图片的方法
方式一
将读取出来的torch.FloatTensor转换为numpy
np_image = tensor_image.numpy()
np_image = np.transpose(np_image, [1, 2, 0])
plt.show()
方式二
利用torchvision中的功能函数,一般用于批量显示图片。
img=torchvision.utils.make_grid(img).numpy() plt.imshow(np.transpose(img,(1,2,0))) plt.show()
快去成为你想要的样子!