PyTorch collate_fn 测试用例
`collate_fn` 函数用于处理数据加载器(DataLoader)中的一批数据。在 PyTorch 中使用 DataLoader 时,通过设置 `collate_fn`,我们可以决定如何将多个样本数据整合成一个 batch。默认的 collate 函数会把同一 batch 中各样本对应的数组直接堆叠(stack),因此要求它们形状一致;当每张图像的标注框数量不同时(如本例),就需要用户自定义 `collate_fn` 来满足需求。
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
class MyDataset(Dataset):
    """Map-style dataset pairing HWC images with per-image annotation arrays.

    Each item is an ``(image, label)`` pair. The image is converted to
    float32 and reordered from ``(H, W, C)`` to ``(C, H, W)``. The label is
    returned untouched, so different items may carry label arrays with
    different numbers of rows.
    """

    def __init__(self, imgs, labels):
        # Parallel sequences: imgs[k] is paired with labels[k].
        self.imgs = imgs
        self.labels = labels

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        # h, w, c -> c, h, w  (e.g. [300, 150, 3] -> [3, 300, 150])
        chw_img = np.transpose(self.imgs[idx].astype(np.float32), (2, 0, 1))
        return chw_img, self.labels[idx]
# Example with batch_size=3: `batch` is a list of 3 (img, label) tuples, e.g.
#   batch[0]: (ndarray[3, 300, 150], ndarray[4, 5])
#   batch[1]: (ndarray[3, 300, 150], ndarray[2, 5])
#   batch[2]: (ndarray[3, 300, 150], ndarray[4, 5])
def my_collate_fn(batch):
    """Collate samples whose annotation counts differ into one batch.

    Images (which must all share one shape) are stacked along a new leading
    dimension. Annotations are kept as a per-image list, because each image
    may have a different number of annotation rows (bounding boxes), which
    cannot be stacked into a single tensor.

    Arguments:
        batch: (list) sequence of samples as returned by the dataset's
            ``__getitem__``; ``sample[0]`` is the image array and
            ``sample[1]`` is that image's annotation array.

    Return:
        A tuple containing:
            1) (tensor) float32 images stacked on a new dim 0,
               i.e. shape ``[len(batch), *image.shape]``
            2) (list of tensors) float32 annotations, one tensor per image,
               in batch order
    """
    # torch.tensor always copies its input, so (as with the legacy
    # torch.FloatTensor(...) constructor used previously) the returned
    # tensors never alias the dataset's arrays.
    imgs = [torch.tensor(sample[0], dtype=torch.float32) for sample in batch]
    targets = [torch.tensor(sample[1], dtype=torch.float32) for sample in batch]
    imgs_out = torch.stack(imgs, 0)  # e.g. [3, 3, 300, 150] for batch_size=3
    return imgs_out, targets
# Build a small synthetic detection-style dataset: `nums` random H x W x 3
# images, each paired with a random number (0-9) of 5-column label rows.
img_data = []
label_data = []
nums = 34
H = 300
W = 150
for _ in range(nums):
    # randint's `high` is exclusive, so high=256 covers the full 0..255
    # range of an 8-bit image (high=255 would never produce 255).
    random_img = np.random.randint(low=0, high=256, size=(H, W, 3), dtype=np.uint8)
    # Can be 0, in which case the label array is empty with shape (0, 5).
    nums_target = np.random.randint(low=0, high=10)
    random_xyxy_label = np.random.random((nums_target, 5))
    img_data.append(random_img)
    label_data.append(random_xyxy_label)

dataset = MyDataset(img_data, label_data)
# A custom collate_fn is needed because label arrays differ in row count.
# With 34 samples and batch_size=3, the final batch holds only 1 sample.
dataloader = DataLoader(dataset, batch_size=3, collate_fn=my_collate_fn)
for cnt, (img, label) in enumerate(dataloader):
    print("==>>", cnt, ", img shape=", img.shape)
    for target in label:
        print("label shape=", target.shape)
打印如下(数据为随机生成,每次运行得到的标注形状都会不同;此处仅节选前 3 个 batch):
==>> 0 , img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([8, 5])
label shape= torch.Size([2, 5])
label shape= torch.Size([5, 5])
==>> 1 , img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([3, 5])
label shape= torch.Size([8, 5])
label shape= torch.Size([5, 5])
==>> 2 , img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([7, 5])
label shape= torch.Size([1, 5])
label shape= torch.Size([8, 5])
好记性不如烂键盘---点滴、积累、进步!