torch 中各种图像格式转化
- PIL:使用 Python 图像处理库 PIL(Pillow,第三方库,需单独安装)读取出来的图片格式
- numpy:使用python-opencv库读取出来的图片格式
- tensor:pytorch中训练时所采用的张量格式(当然也可以用来表示图片)
a = torch.randn(3,496,740)
c = a.numpy()
print('c',c.shape)
d =c.transpose(1,2,0)
print('d',d.shape)
e = Image.fromarray(np.uint8(d))
print('e',e.size)
b = transforms.ToPILImage()(torch.squeeze(a.data.cpu(), 0))
print('b',b.size)#(740,496)
print(b)

输出:
c (3, 496, 740)
d (496, 740, 3)
e (740, 496)
b (740, 496)
<PIL.Image.Image image mode=RGB size=740x496 at 0x7F2A3D868208>

torchvision.transforms.ToPILImage 对于一个Tensor的转化过程是:
1. 将张量的每个元素乘上255
2. 将张量的数据类型由FloatTensor转化成uint8
3. 将张量转化成numpy的ndarray类型
4. 对ndarray对象做transpose (1, 2, 0)的操作(即把维度从 C×H×W 重排为 H×W×C)
5. 利用Image下的fromarray函数,将ndarray对象转化成PIL Image形式
6. 输出PIL Image
import os

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

# Device that the loader helpers move tensors to. A later module-level
# assignment to ``device`` simply overrides this default.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The loader uses torchvision's built-in transforms to turn a PIL image
# into a tensor.
loader = transforms.Compose([transforms.ToTensor()])
# Inverse conversion: tensor back to a PIL image.
unloader = transforms.ToPILImage()


def image_loader(image_name):
    """Load an image file and return it as a batched float tensor.

    Args:
        image_name: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The RGB image as a ``torch.float`` tensor on ``device`` with a
        leading batch dimension of size 1.
    """
    # Image.open keeps the file open until the data is read; the context
    # manager closes the handle once the RGB copy has been made.
    with Image.open(image_name) as source:
        rgb_image = source.convert('RGB')
    return PIL_to_tensor(rgb_image)


def PIL_to_tensor(image):
    """Convert a PIL image to a batched float tensor.

    Args:
        image: PIL image to convert.

    Returns:
        A ``torch.float`` tensor on ``device`` with a leading batch
        dimension of size 1.
    """
    # unsqueeze(0) inserts a new leading axis of size 1 (a fake batch
    # dimension required by network inputs); it does not pad any data.
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)


def tensor_to_PIL(tensor):
    """Convert a (possibly batched) image tensor to a PIL image.

    Args:
        tensor: Image tensor, optionally with a leading batch dimension
            of size 1.

    Returns:
        The PIL image produced by ``transforms.ToPILImage``.
    """
    # Clone on the CPU so the caller's tensor is never modified.
    image = tensor.cpu().clone()
    # Remove the fake batch dimension added by the loader helpers.
    image = image.squeeze(0)
    return unloader(image)


def imshow(tensor, title=None):
    """Display an image tensor with matplotlib.

    Args:
        tensor: Image tensor, optionally with a leading batch dimension
            of size 1.
        title: Optional title drawn above the image.
    """
    plt.imshow(tensor_to_PIL(tensor))
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


def save_image(tensor, **para):
    """Save an image tensor as a JPEG under ``results_<num>/``.

    The file name encodes the run parameters:
    ``s<style_weight>-c<content_weight>-l<lr>-e<epoch>-sl<style_loss>-cl<content_loss>.jpg``.

    Args:
        tensor: Image tensor, optionally with a leading batch dimension
            of size 1.
        **para: Must contain ``style_weight``, ``content_weight``, ``lr``,
            ``epoch``, ``style_loss`` and ``content_loss``. May contain
            ``num`` (the run identifier used in the directory name); when
            it is absent the module-level ``num`` is used, as before.

    Raises:
        KeyError: If a required entry is missing from ``para``.
        NameError: If ``num`` is neither passed nor defined at module level.
    """
    # ``num`` used to be readable only as a module-level name; accepting it
    # as a keyword makes the function usable on its own while keeping the
    # old global lookup for existing callers.
    run_id = para['num'] if 'num' in para else num  # noqa: F821

    # NOTE: ``{:4f}`` is a minimum field width of 4 (six decimals are
    # printed), not four decimals; kept as-is so existing file names do not
    # change.
    file_name = 's{}-c{}-l{}-e{}-sl{:4f}-cl{:4f}.jpg'.format(
        para['style_weight'], para['content_weight'], para['lr'],
        para['epoch'], para['style_loss'], para['content_loss'])

    # Create the directory the file is actually written to. The previous
    # code created 'results' but saved into 'results_<num>', so the save
    # failed whenever that directory did not already exist.
    out_dir = 'results_{}'.format(run_id)
    os.makedirs(out_dir, exist_ok=True)

    tensor_to_PIL(tensor).save(os.path.join(out_dir, file_name))