Using collate_fn in PyTorch
By default, DataLoader uses a collate_fn to pack a series of images and targets into tensors (the first dimension of each tensor is the batch size). The default collate_fn expects all the images in a batch to have the same size because it uses torch.stack() to pack the images. If the images provided by the Dataset have variable sizes, you have to provide your own custom collate_fn. A simple example is shown below:

    import torch

    # a simple custom collate function, just to show the idea
    # `batch` is a list of tuples where the first element is the image tensor
    # and the second element is the corresponding label
    def my_collate(batch):
        data = [item[0] for item in batch]  # just form a list of tensors
        target = [item[1] for item in batch]
        target = torch.LongTensor(target)
        return [data, target]

Reference: Writing Your Own Custom Dataset for Classification in PyTorch
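To make the behaviour of my_collate concrete without needing an image folder, here is a minimal sketch that feeds it from a toy dataset. The VariableSizeImages class and its random tensors are hypothetical stand-ins invented for this illustration, not part of the referenced article; only my_collate follows the snippet above.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class VariableSizeImages(Dataset):
    """Hypothetical toy dataset: random 3-channel images of varying size."""

    def __init__(self, sizes):
        self.sizes = sizes  # list of (height, width) pairs

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        h, w = self.sizes[idx]
        image = torch.rand(3, h, w)  # stand-in for a decoded image
        label = idx % 2              # stand-in for a class index
        return image, label


def my_collate(batch):
    data = [item[0] for item in batch]  # keep the images as a list
    target = torch.LongTensor([item[1] for item in batch])
    return [data, target]


loader = DataLoader(
    VariableSizeImages([(32, 48), (40, 40), (28, 64), (50, 30)]),
    batch_size=2,
    collate_fn=my_collate,
)

batches = list(loader)
for images, labels in batches:
    # each batch is a list of differently shaped tensors plus a label tensor
    print([tuple(img.shape) for img in images], labels)
```

Each batch comes back as a plain Python list of image tensors plus one LongTensor of labels, so the images never need to share a height or width.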
By default, torch stacks the input images to form a tensor of size N*C*H*W, so every image in the batch must have the same height and width. To load a batch of variable-size input images, we have to use our own collate_fn, which is used to pack a batch of images.
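The size restriction comes from torch.stack() itself, which the following small sketch tries to show. The tensor shapes are arbitrary values chosen for illustration.

```python
import torch

# Images of identical size stack into one N*C*H*W tensor.
same_size = [torch.rand(3, 32, 32) for _ in range(4)]
stacked = torch.stack(same_size)
print(stacked.shape)

# Images of different size cannot be stacked; torch raises a RuntimeError.
mixed_size = [torch.rand(3, 32, 32), torch.rand(3, 40, 48)]
try:
    torch.stack(mixed_size)
    stack_failed = False
except RuntimeError as err:
    stack_failed = True
    print("torch.stack failed:", err)
```

This is the error a DataLoader with the default collate_fn runs into when a Dataset yields images of different sizes.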
For image classification, the input to collate_fn is a list of length batch_size. Each element is a tuple whose first element is the input image (a torch.FloatTensor) and whose second element is the image label, which is simply an int. Because the samples in a batch have different sizes, we can store them in a list and store the corresponding labels in a torch.LongTensor. Then we put the image list and the label tensor into a list and return the result.
Here is a very simple snippet to demonstrate how to write a custom collate_fn:

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import torchvision.datasets as datasets
    import matplotlib.pyplot as plt

    # a simple custom collate function, just to show the idea
    def my_collate(batch):
        data = [item[0] for item in batch]
        target = [item[1] for item in batch]
        target = torch.LongTensor(target)
        return [data, target]


    def show_image_batch(img_list, title=None):
        num = len(img_list)
        fig = plt.figure()
        for i in range(num):
            ax = fig.add_subplot(1, num, i + 1)
            # convert C*H*W to H*W*C, the layout imshow expects
            ax.imshow(img_list[i].numpy().transpose([1, 2, 0]))
            if title is not None:
                ax.set_title(title[i])

        plt.show()

    # no RandomCrop here, to show that the custom collate_fn can handle images of different sizes
    train_transforms = transforms.Compose([transforms.Resize(size=224),
                                           transforms.ToTensor(),
                                           ])

    # change root to a valid dir on your system, see the ImageFolder documentation for more info
    train_dataset = datasets.ImageFolder(root="/hd1/jdhao/toyset",
                                         transform=train_transforms)

    trainset = DataLoader(dataset=train_dataset,
                          batch_size=4,
                          shuffle=True,
                          collate_fn=my_collate,  # use custom collate function here
                          pin_memory=True)

    trainiter = iter(trainset)
    imgs, labels = next(trainiter)

    # print(type(imgs), type(labels))
    show_image_batch(imgs, title=[train_dataset.classes[x] for x in labels])

Reference: How to create a dataloader with variable-size input
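Once a batch arrives as a list of images, it can no longer be passed to a model as one tensor. One possible way to consume it, sketched below, is to forward each image separately with a batch dimension of 1 and then concatenate the outputs. The tiny model, the random images, and the labels are hypothetical placeholders for this illustration; the adaptive pooling layer is an assumption made so that the network accepts any input size.

```python
import torch
import torch.nn as nn

# Hypothetical tiny classifier; adaptive pooling makes it independent of H and W.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # (1, 8, H, W) -> (1, 8, 1, 1) for any H, W
    nn.Flatten(),
    nn.Linear(8, 2),
)

# Stand-in for one batch as returned by my_collate: image list plus label tensor.
imgs = [torch.rand(3, 32, 48), torch.rand(3, 40, 40), torch.rand(3, 28, 64)]
labels = torch.LongTensor([0, 1, 0])

# Forward each image on its own, then join the per-image logits into one tensor.
logits = torch.cat([model(img.unsqueeze(0)) for img in imgs], dim=0)
loss = nn.functional.cross_entropy(logits, labels)
print(logits.shape, loss.item())
```

Forwarding images one at a time gives up the speed of batched computation, so this is best seen as a simple baseline rather than a recommended training setup.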
A test case for DataLoader:

    import torch
    import torch.utils.data as Data
    import numpy as np

    test = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    # sliding windows of length 3 as inputs, the first value of each window as target
    inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
    target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

    torch_dataset = Data.TensorDataset(inputing, target)
    batch = 3

    loader = Data.DataLoader(
        dataset=torch_dataset,
        batch_size=batch,  # batch size
        # if the number of samples in the dataset is not divisible by batch_size,
        # the last batch uses however many samples remain
        collate_fn=lambda x: (
            torch.cat(
                [x[i][j].unsqueeze(0) for i in range(len(x))], 0
            ).unsqueeze(0) for j in range(len(x[0]))
        )
    )

    for (i, j) in loader:
        print(i)
        print(j)

Reference: The collate_fn parameter of DataLoader
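The lambda in the test case above is dense, so the sketch below rewrites the same logic as a named function and records the shapes it produces. The name collate_with_leading_dim is made up for this illustration, and the function returns a tuple instead of a generator, which is a small deviation from the original lambda.

```python
import torch
import torch.utils.data as Data


def collate_with_leading_dim(samples):
    # For every field j (input, target): stack that field across the samples,
    # then add one extra leading dimension of size 1.
    return tuple(
        torch.cat([sample[j].unsqueeze(0) for sample in samples], 0).unsqueeze(0)
        for j in range(len(samples[0]))
    )


values = torch.arange(12)
inputs = torch.stack([values[i:i + 3] for i in range(10)])   # shape (10, 3)
targets = torch.stack([values[i:i + 1] for i in range(10)])  # shape (10, 1)

loader = Data.DataLoader(
    Data.TensorDataset(inputs, targets),
    batch_size=3,
    collate_fn=collate_with_leading_dim,
)

shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in loader]
print(shapes)
```

With 10 samples and a batch size of 3, the first three batches have input shape (1, 3, 3) and target shape (1, 3, 1), while the last batch holds the single remaining sample with shapes (1, 1, 3) and (1, 1, 1).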
Reading variable-length data in PyTorch