[PyTorch入门]之数据导入与处理
数据导入与处理
来自这里。
在解决任何机器学习问题时,都需要在处理数据上花费大量的努力。PyTorch提供了很多工具来简化数据加载,希望使代码更具可读性。在本教程中,我们将学习如何从繁琐的数据中加载、预处理数据或增强数据。
开始本教程之前,请确认你已安装如下Python包:
- scikit-image:图像IO操作和格式转换
- pandas:更方便解析CSV
我们接下来要处理的数据集是人脸姿态。这意味着人脸的注释如下:
总之,每个面部都有68个不同标记点。
可以从这里下载数据集,并将其解压后存放到目录‘data/faces/’。
数据集来自带有面部注释的CSV文件,文件内容类似以下格式:
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y 0805personali01.jpg,27,83,27,98, ... 84,134 1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
接下来我们快速读取CSV文件,并从(N,2)数组中获取注释,N表示标记数量。
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv') n = 65 img_name = landmarks_frame.iloc[n, 0] landmarks = landmarks_frame.iloc[n, 1:].as_matrix() landmarks = landmarks.astype('float').reshape(-1, 2) print('Image name: {}'.format(img_name)) print('Landmarks shape: {}'.format(landmarks.shape)) print('First 4 Landmarks: {}'.format(landmarks[:4]))
输出:
Image name: person-7.jpg Landmarks shape: (68, 2) First 4 Landmarks: [[32. 65.] [33. 76.] [34. 86.] [34. 97.]]
现在我们写一个简单的帮助函数:展示图片和它的标记,用它来展示样本。
def show_landmarks(image,landmarks): ''' 展示带标记点的图像 ''' plt.imshow(image) plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r') plt.pause(10) plt.figure() show_landmarks(io.imread(os.path.join('data/faces',img_name)),landmarks) plt.show()
数据集类(Dataset class)
torch.utils.data.Dataset
是一个表示数据集的抽象类。你自定义的数据集应该继承Dataset
并重写以下方法:
- len 这样
len(dataset)
是可以返回数据集的大小 - getitem 支持索引操作,比如
dataset[i]
来获取第i个样本。
现在我们来实现我们的面部标记数据集类。我们将在__init__
中读取CSV,然后再__getitem__
中读取图像。这样可以高效利用内存,因为所有的图像并不是都存在在内存中,而是按需读取。
我们数据集的样本是字典格式的:{'image':image,'landmarks':landmarks}
。我们的数据集将采用可选参数transform
,以便任何必要的处理都可以被应用在样本上。在下一节中我们会看到transform
的用途。
class FaceLandmarksDataset(Dataset): ''' Face Landmarks Dataset ''' def __init__(self,csv_file,root_dir,transform=None): ''' param csv_file(string): 带注释的CSV文件路径 param root_dit(string): 存储图像的路径 param transform(callable,optional): 被应用到样本的可选transform操作 ''' self.landmarks_frame = pd.read_csv(csv_file) self.root_dir = root_dir self.transform = transform def __len__(self): return len(self.landmarks_frame) def __getitem__(self,idx): img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx,0]) image = io.imread(img_name) landmarks = self.landmarks_frame.iloc[idx,1:] landmarks = np.array([landmarks]) landmarks = landmarks.astype('float').reshape(-1,2) sample = {'image':image,'landmarks':landmarks} if self.transform: sample = self.transform(sample) return sample
现在我们实例化这个类,并且迭代输出部分样本。我们打印输出前4个样本并展示它们的标记。
face_dataset = FaceLandmarksDataset( csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/') fig = plt.figure() for i in range(len(face_dataset)): sample = face_dataset[i] print(i, sample['image'].shape, sample['landmarks'].shape) ax = plt.subplot(1, 4, i+1) plt.tight_layout() ax.set_title('Sample #{}'.format(i)) ax.axis('off') show_landmarks(**sample) if i == 3: plt.show() break
输出:
0 (324, 215, 3) (68, 2) 1 (500, 333, 3) (68, 2) 2 (250, 258, 3) (68, 2) 3 (434, 290, 3) (68, 2)
Transforms(转换)
从上面的例子可以看出这些样本的尺寸并不一致。大多数神经网络都期望图像的尺寸是固定的。这样的话,我们就需要一些处理代码。接下来我们创建三个变换函数:
Rescale
:缩放图像RandomCrop
:随机裁剪图像。这是数据扩充。ToTensor
:将numpy图像转为torch图像(我们需要交换轴)。
我们将以类而不是简单的函数的方式来实现它们,这样就不需要在每次调用时都传递转换需要的参数。这样我们只需要实现__call__
方法,需要的话还可以实现__init__
方法。然后我们可以按如下的方式使用:
tsfm = Transform(params) transformed_sample = tsfm(sample)
下面展示如何将这些转换同时应用在图像和标记点。
class Rescale(object): ''' 按给定的尺寸缩放图像 param output_size (tuple or int): 目标输出尺寸。如果是tuple,输出为匹配的输出尺寸;如果是int,则匹配较小的图像边缘,保证相同的长宽比例。 ''' def __init__(self, output_size): assert isinstance(output_size, (int, tuple)) self.output_size = output_size def __call__(self, sample): image, landmarks = sample['image'], sample['landmarks'] h, w = image.shape[:2] if isinstance(self.output_size, int): if h > w: new_h, new_w = self.output_size*h/w, self.output_size else: new_h, new_w = self.output_size, self.output_size*w/h else: new_h, new_w = self.output_size new_h, new_w = int(new_h), int(new_w) img = transform.resize(image, (new_h, new_w)) landmarks = landmarks*[new_w/w, new_h/h] return {'image': img, 'landmarks': landmarks} class RandomCrop(object): ''' 随机裁剪图像 param output_size (tuple or int): 目标输出尺寸。如果是int,正方形裁剪 ''' def __init__(self, output_size): assert isinstance(output_size, (int, tuple)) if isinstance(output_size, int): self.output_size = (output_size, output_size) else: assert len(output_size) == 2 self.output_size = output_size def __call__(self, sample): image, landmarks = sample['image'], sample['landmarks'] h, w = image.shape[:2] new_h, new_w = self.output_size top = np.random.randint(0, h-new_h) left = np.random.randint(0, w-new_w) image = image[top:top+new_h, left:left+new_w] landmarks = landmarks - [left, top] return {'image': image, 'landmarks': landmarks} class ToTensor(object): ''' 将ndarrays格式样本转换为Tensors ''' def __call__(self, sample): image, landmarks = sample['image'], sample['landmarks'] image = image.transpose((2, 0, 1)) return {'image': torch.from_numpy(image), 'landmarks': torch.from_numpy(landmarks)}
组合变换
现在,我们在样本上应用转换。
比如我们想将图片的短边缩放为256然后在随机裁剪出一个224的正方形,那么我们将用到Rescale
和RandomCrop
。torchvision.transforms.Compost
可以帮助我们完成上述组合操作。
scale = Rescale(256) crop = RandomCrop(128) composed = transforms.Compose([Rescale(256), RandomCrop(224)]) fig = plt.figure() sample = face_dataset[65] for i, tsfrm in enumerate([scale, crop, composed]): transformed_sample = tsfrm(sample) ax = plt.subplot(1, 3, i+1) plt.tight_layout() ax.set_title(type(tsfrm).__name__) show_landmarks(**transformed_sample) plt.show()
遍历数据集
接下来我们将上面的代码整合起来,创建一个带有组合变换的数据集。综上所述,每次采样该数据集时:
- 从文件中动态读取图像
- 转换应用到读取的图像上
- 由于其中一种转换是随机的,因此数据在抽样时得到了扩充
我们可以是像之前一样用for i in range
循环遍历创建的数据集:
transformed_dataset = FaceLandmarksDataset( csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/', transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()]) ) for i in range(len(transformed_dataset)): sample = transformed_dataset[i] print(i, sample['image'].size(), sample['landmarks'].size()) if i == 3: break
输出:
0 torch.Size([3, 224, 224]) torch.Size([68, 2]) 1 torch.Size([3, 224, 224]) torch.Size([68, 2]) 2 torch.Size([3, 224, 224]) torch.Size([68, 2]) 3 torch.Size([3, 224, 224]) torch.Size([68, 2])
然而,只是使用建档的for
训练遍历数据,我们将丢失很多特征。尤其是我们丢失了:
- 批量处理数据
- 移动数据
- 使用
multiprocessing
并行加载数据
torch.utils.data.DataLoader
是一个提供了所有这些功能的迭代器。接下来使用的参数是明朗的。一个有趣的参数是collate_fn
。你可以使用collate_fn
指定需要如何对样本进行批量处理。然而,默认的collate足够胜任大多数使用场景。
dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=4) def show_landmarks_batch(sample_batched): ''' 批量展示样本 ''' images_batch, landmarks_batch = sample_batched['image'], sample_batched['landmarks'] batch_size = len(sample_batched) im_size = images_batch.size(2) grid_border_size = 2 grid = utils.make_grid(images_batch) plt.imshow(grid.numpy().transpose((1, 2, 0))) for i in range(batch_size): plt.scatter( landmarks_batch[i, :, 0].numpy() + i * im_size + (i+1)*grid_border_size, landmarks_batch[i, :, 1].numpy() + grid_border_size, s=10, marker='.', c='r' ) plt.title('Batch from dataloader') for i_batch,sample_batched in enumerate(dataloader): print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size()) if i_batch == 3: plt.figure() show_landmarks_batch(sample_batched) plt.axis('off') plt.ioff() plt.show() break
输出:
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2]) 1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2]) 2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2]) 3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
后续:torchvision
在本教程中,我们了解了如何实现并使用数据集、转换和数据导入。torchvision
包提供了一些常用的数据集和转换。你甚至可能不需要编写自定义的类。在torchvision中最常用的数据集是ImageFolder
。它假设图像的组织方式如下所示:
root/ants/xxx.png root/ants/xxy.jpeg root/ants/xxz.png . . . root/bees/123.jpg root/bees/nsdf3.png root/bees/asd932_.png
‘ants’、‘bees’等等都是类的标签。在PIL.Image
上操作的类似常用的转化,如RandomHorizontalFlip
、Scale
,都是可用的。你可以使用它们来编写想下面的数据导入:
import torch from torchvision import transforms, datasets data_transform = transforms.Compose([ transforms.RandomSizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train', transform=data_transform) dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset, batch_size=4, shuffle=True, num_workers=4)
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 一个奇形怪状的面试题:Bean中的CHM要不要加volatile?
· [.NET]调用本地 Deepseek 模型
· 一个费力不讨好的项目,让我损失了近一半的绩效!
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性
· 全网最简单!3分钟用满血DeepSeek R1开发一款AI智能客服,零代码轻松接入微信、公众号、小程