pytorch读入图片并显示np.transpose(np_image, [1, 2, 0])

np.transpose(np_image, [1, 2, 0])

pytorch中读入图片并进行显示时

# visualization of an example of training data
def show_image(tensor_image):
    np_image = tensor_image.numpy()
    np_image = np.transpose(np_image, [1, 2, 0])*0.5 + 0.5 # 转置后做逆归一化
    plt.imshow(np_image)
plt.show()   X
= iter(train_loader).next()[0] print(X.size()) show_image(X)

其中有一行命令用来转置

np.transpose(np_image, [1, 2, 0])

主要是Pytorch中使用的数据格式与plt.imshow()函数的格式不一致

Pytorch中为[Channels, H, W]

而plt.imshow()中则是[H, W, Channels]

因此,要先转置一下。

该函数的解释见:plt.imshow()

pytorch读入并显示图片的方法

方式一

将读取出来的torch.FloatTensor转换为numpy

np_image = tensor_image.numpy()
np_image = np.transpose(np_image, [1, 2, 0])
plt.show()

方式二

利用torchvision中的功能函数,一般用于批量显示图片。

img=torchvision.utils.make_grid(img).numpy()
plt.imshow(np.transpose(img,(1,2,0)))
plt.show()

 

posted on 2020-05-24 11:42  那抹阳光1994  阅读(5125)  评论(0编辑  收藏  举报

导航