PyTorch collate_fn usage

By default, DataLoader uses its collate_fn to pack a list of images and targets into tensors (the first dimension of each tensor is the batch size). The default collate_fn expects all images in a batch to have the same size, because it uses torch.stack() to pack them. If the images provided by the Dataset have variable sizes, you have to provide your own custom collate_fn. A simple example is shown below:

import torch

# A simple custom collate function, just to show the idea.
# `batch` is a list of tuples where the first element is an image tensor
# and the second element is the corresponding label.
def my_collate(batch):
    data = [item[0] for item in batch]  # just form a list of tensors
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)
    return [data, target]

Reference: Writing Your Own Custom Dataset for Classification in PyTorch

 

 

By default, the collate_fn uses torch.stack() to combine the input images into a tensor of size N*C*H*W, so every image in the batch must have the same height and width. To load a batch of variable-size input images, we have to supply our own collate_fn, which is the function used to pack a batch of images.
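
A quick way to see this behaviour is to call the default collate function directly. The sketch below assumes a PyTorch version where `default_collate` is importable from `torch.utils.data.dataloader` (newer releases also expose it as `torch.utils.data.default_collate`); the zero tensors simply stand in for real images.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Four (image, label) samples of identical size 3x224x224
same_size = [(torch.zeros(3, 224, 224), i) for i in range(4)]
images, labels = default_collate(same_size)
print(images.shape)  # torch.Size([4, 3, 224, 224]), i.e. N*C*H*W
print(labels)        # tensor([0, 1, 2, 3])

# Two samples whose heights differ: torch.stack() cannot combine them
mixed_size = [(torch.zeros(3, 224, 224), 0), (torch.zeros(3, 256, 224), 1)]
try:
    default_collate(mixed_size)
    stacked_ok = True
except RuntimeError as err:
    stacked_ok = False
    print("default_collate failed:", err)
```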

For image classification, the input to collate_fn is a list of length batch_size. Each element is a tuple whose first element is the input image (a torch.FloatTensor) and whose second element is the image label, which is simply an int. Because the samples in a batch have different sizes, we can store them in a list and store the corresponding labels in a torch.LongTensor. Then we put the image list and the label tensor into a list and return the result.
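
To make that input/output contract concrete, here is a small sketch that calls a collate function of the same shape directly on a hand-built batch, roughly what DataLoader does after fetching batch_size samples. The image sizes and labels are made up for illustration.

```python
import torch

def my_collate(batch):
    # batch: list of (image_tensor, int_label) tuples, one per sample
    data = [item[0] for item in batch]                      # keep images as a plain list
    target = torch.LongTensor([item[1] for item in batch])  # labels -> 1-D LongTensor
    return [data, target]

# Hand-built batch of three images with different heights and widths
batch = [
    (torch.rand(3, 224, 300), 0),
    (torch.rand(3, 224, 224), 2),
    (torch.rand(3, 256, 224), 1),
]
data, target = my_collate(batch)
print([tuple(img.shape) for img in data])  # [(3, 224, 300), (3, 224, 224), (3, 256, 224)]
print(target)                              # tensor([0, 2, 1])
```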

Here is a very simple snippet that demonstrates how to write a custom collate_fn:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

# a simple custom collate function, just to show the idea
def my_collate(batch):
    data = [item[0] for item in batch]  # images of different sizes stay in a list
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)   # int labels -> 1-D LongTensor
    return [data, target]


def show_image_batch(img_list, title=None):
    num = len(img_list)
    fig = plt.figure()
    for i in range(num):
        ax = fig.add_subplot(1, num, i + 1)
        # CHW tensor -> HWC array, which is the layout imshow expects
        ax.imshow(img_list[i].numpy().transpose([1, 2, 0]))
        if title is not None:
            ax.set_title(title[i])

    plt.show()

# do not use RandomCrop, to show that the custom collate_fn can handle images of different sizes;
# Resize with a single int scales the shorter side to 224 and keeps the aspect ratio
# (transforms.Scale is the deprecated old name of Resize, removed in newer torchvision)
train_transforms = transforms.Compose([transforms.Resize(size=224),
                                       transforms.ToTensor(),
                                       ])

# change root to a valid dir on your system, see the ImageFolder documentation for more info
train_dataset = datasets.ImageFolder(root="/hd1/jdhao/toyset",
                                     transform=train_transforms)

trainset = DataLoader(dataset=train_dataset,
                      batch_size=4,
                      shuffle=True,
                      collate_fn=my_collate,  # use custom collate function here
                      pin_memory=True)

trainiter = iter(trainset)
imgs, labels = next(trainiter)  # newer PyTorch iterators have no .next(); use built-in next()

# print(type(imgs), type(labels))
show_image_batch(imgs, title=[train_dataset.classes[x] for x in labels.tolist()])
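
The snippet above needs torchvision, matplotlib and an image directory. The following sketch exercises the same my_collate with a hypothetical in-memory dataset (`RandomSizeImages` is an illustrative name, not a PyTorch class) in place of ImageFolder, so it runs anywhere PyTorch is installed.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class RandomSizeImages(Dataset):
    """Hypothetical toy dataset: each sample is a random 3xHxW image with its own size."""

    def __init__(self, sizes, labels):
        self.sizes = sizes
        self.labels = labels

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        h, w = self.sizes[idx]
        return torch.rand(3, h, w), self.labels[idx]

def my_collate(batch):
    data = [item[0] for item in batch]                      # list of differently sized images
    target = torch.LongTensor([item[1] for item in batch])  # 1-D tensor of labels
    return [data, target]

dataset = RandomSizeImages(
    sizes=[(224, 300), (224, 224), (256, 224), (200, 180), (224, 224)],
    labels=[0, 1, 2, 1, 0],
)
loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=my_collate)

for imgs, labels in loader:
    # each batch is a list of image tensors plus a LongTensor of labels
    print([tuple(img.shape) for img in imgs], labels.tolist())
```

With five samples and batch_size=2, the last batch contains a single image; the collate function does not care, since it never stacks the images.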

Reference: How to create a dataloader with variable-size input

 

 

A test case for DataLoader:

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))  # shape (10, 3)
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))    # shape (10, 1)

torch_dataset = Data.TensorDataset(inputing, target)
batch = 3

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch,  # batch size
    # if the number of samples in the dataset is not divisible by batch_size,
    # the last batch simply contains whatever samples remain
    # x is a list of (input, target) tuples; for each field j, the samples are
    # concatenated along a new dim 0 -> (batch, ...), then unsqueeze(0) adds a
    # leading dim of size 1 -> (1, batch, ...)
    collate_fn=lambda x: tuple(
        torch.cat(
            [x[i][j].unsqueeze(0) for i in range(len(x))], 0
            ).unsqueeze(0) for j in range(len(x[0]))
        )
    )

for (i, j) in loader:
    print(i)
    print(j)
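
The comment on batch_size describes DataLoader's default `drop_last=False`: with 10 samples and batch_size=3, the last batch holds the single leftover sample. A minimal sketch comparing both settings, reusing the same toy data but the default collate_fn, so the batch dimension is the first one (`batch_sizes` is a hypothetical helper name):

```python
import numpy as np
import torch
import torch.utils.data as Data

test = np.arange(12)
inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))  # shape (10, 3)
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))    # shape (10, 1)
torch_dataset = Data.TensorDataset(inputing, target)

def batch_sizes(drop_last):
    loader = Data.DataLoader(torch_dataset, batch_size=3, drop_last=drop_last)
    # with the default collate_fn, x has shape (batch, 3)
    return [x.shape[0] for x, _ in loader]

print(batch_sizes(drop_last=False))  # [3, 3, 3, 1]: the last batch holds the remainder
print(batch_sizes(drop_last=True))   # [3, 3, 3]: the incomplete batch is dropped
```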

Reference: The collate_fn parameter of DataLoader

 

Reading variable-length data in PyTorch

https://zhuanlan.zhihu.com/p/60129684
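
The linked article covers variable-length inputs. One common approach (not necessarily the one the article uses) is to pad inside collate_fn with `torch.nn.utils.rnn.pad_sequence` and also return the original lengths. A minimal sketch with made-up token sequences; `pad_collate` is a hypothetical name:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical variable-length samples: (1-D token tensor, int label)
samples = [
    (torch.tensor([1, 2, 3]), 0),
    (torch.tensor([4, 5]), 1),
    (torch.tensor([6, 7, 8, 9]), 0),
]

def pad_collate(batch):
    seqs = [item[0] for item in batch]
    lengths = torch.tensor([len(s) for s in seqs])  # original lengths before padding
    padded = pad_sequence(seqs, batch_first=True, padding_value=0)  # (batch, max_len)
    labels = torch.tensor([item[1] for item in batch])
    return padded, lengths, labels

# a plain list works as a map-style dataset (it has __getitem__ and __len__)
loader = DataLoader(samples, batch_size=3, collate_fn=pad_collate)
padded, lengths, labels = next(iter(loader))
print(padded)
# tensor([[1, 2, 3, 0],
#         [4, 5, 0, 0],
#         [6, 7, 8, 9]])
print(lengths)  # tensor([3, 2, 4])
```

Returning `lengths` lets a model mask out the padding, or pack the batch with `pack_padded_sequence` before feeding it to an RNN.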

posted @ 2019-06-08 12:00  三年一梦