Pytorch:将图像tensor数据用Opencv显示
Pytorch:将图像tensor数据用Opencv显示
Pytorch:将图像tensor数据用Opencv显示 - 知乎 (zhihu.com)
将图像tensor数据用Opencv显示
首先导入相关库:*
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
利用PIL中的Image打开一张图片
image2=Image.open('pikachu.jpg')
这里print看一下image2的图像数据类型,这里可以直接调用image2.show()直接显示:
print(image2)
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=814x982 at 0x1E988BA74A8>
将image2转化为tensor数据(为什么转化为tensor,当然是为了方便计算)
transform2=transforms.Compose([transforms.ToTensor()])
tensor2=transform2(image2)
print('tensor2:',tensor2)#打印看一下tensor的数据
print(tensor2.dtype)#torch.float32
print(tensor2.shape)#返回tensor2_shape torch.Size([3, 982, 814])->3通道982*814的RGB图像
tensor2: tensor([[[0.5647, 0.5686, 0.5686, ..., 0.5725, 0.5725, 0.5686],
[0.5647, 0.5686, 0.5725, ..., 0.5725, 0.5725, 0.5686],
[0.5686, 0.5725, 0.5725, ..., 0.5725, 0.5725, 0.5686],
...,
[0.3176, 0.3216, 0.3216, ..., 0.3098, 0.3098, 0.3098],
[0.3176, 0.3216, 0.3216, ..., 0.3098, 0.3098, 0.3098],
[0.3176, 0.3216, 0.3216, ..., 0.3137, 0.3137, 0.3137]],
[[0.7412, 0.7451, 0.7451, ..., 0.7490, 0.7490, 0.7451],
[0.7412, 0.7451, 0.7490, ..., 0.7490, 0.7490, 0.7451],
[0.7451, 0.7490, 0.7490, ..., 0.7490, 0.7490, 0.7451],
...,
[0.5529, 0.5569, 0.5569, ..., 0.5451, 0.5451, 0.5451],
[0.5529, 0.5569, 0.5569, ..., 0.5451, 0.5451, 0.5451],
[0.5529, 0.5569, 0.5569, ..., 0.5490, 0.5490, 0.5490]],
[[0.9059, 0.9098, 0.9098, ..., 0.9098, 0.9098, 0.9059],
[0.9059, 0.9098, 0.9137, ..., 0.9098, 0.9098, 0.9059],
[0.9098, 0.9137, 0.9137, ..., 0.9098, 0.9098, 0.9059],
...,
[0.8275, 0.8314, 0.8314, ..., 0.8275, 0.8275, 0.8275],
[0.8275, 0.8314, 0.8314, ..., 0.8275, 0.8275, 0.8275],
[0.8275, 0.8314, 0.8314, ..., 0.8314, 0.8314, 0.8314]]])
要将tensor图像数据转为opencv支持的图像数据,首先要了解opencv所支持的图像数据:
image3=cv2.imread('pokeman/pikachu/00000000.jpg')
print(image3)
[[[231 189 144]
[232 190 145]
[232 190 145]
...
[232 191 146]
[232 191 146]
[231 190 145]]]
print(image3.shape)
(982, 814, 3)
print(type(image3))
<class 'numpy.ndarray'>
print(image3.dtype)
uint8
所以我们知道opencv支持的图像数据时numpy格式,数据类型为uint8,而且像素值分布在[0,255]之间。 但是从上面的tensor数据可以看出,像素值并不是分布在[0,255],且数据类型为float32,所以需要做一下normalize和数据变换,将图像数据扩展到[0,255]。还有一点不同的是tensor(3,982, 814)、numpy(982, 814, 3)存储的数据维度顺序不同。
array1=tensor2.numpy()#将tensor数据转为numpy数据
maxValue=array1.max()
array1=array1*255/maxValue#normalize,将图像数据扩展到[0,255]
mat=np.uint8(array1)#float32-->uint8
print('mat_shape:',mat.shape)#mat_shape: (3, 982, 814)
mat=mat.transpose(1,2,0)#mat_shape: (982, 814,3)
cv2.imshow("img",mat)
cv2.waitKey()
这是由于opencv中的颜色通道顺序是BGR而PIL、torch里面的图像颜色通道是RGB,利用cvtColor对颜色通道进行转换
mat=cv2.cvtColor(mat,cv2.COLOR_BGR2RGB)
cv2.imshow("img",mat)
cv2.waitKey()