PyTorch collate_fn 测试用例

collate_fn 函数用于处理数据加载器(DataLoader)中的一批数据。在 PyTorch 中使用 DataLoader 时,通过设置 collate_fn,我们可以决定如何将多个样本数据整合到一起成为一个 batch。在某些情况下(例如每个样本的标注数量不同、无法直接堆叠成一个张量时),该函数需要由用户自定义以满足特定需求。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class MyDataset(Dataset):
    """Map-style dataset that pairs each stored image with its label array.

    The images and labels are kept exactly as passed in; the image is only
    converted when a sample is requested.
    """

    def __init__(self, imgs, labels):
        """Store the image sequence and the parallel label sequence."""
        self.imgs, self.labels = imgs, labels

    def __len__(self):
        """Return the number of stored images."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at ``idx``.

        The image is cast to float32 and its last axis is moved to the front
        (channels-last to channels-first, assuming an H, W, C input). The
        label is returned untouched.
        """
        chw_image = np.transpose(self.imgs[idx].astype(np.float32), (2, 0, 1))
        return chw_image, self.labels[idx]

# Example with batch_size=3: `batch` is a list of 3 samples, each a 2-tuple
#   batch[0] -> (ndarray [3, 300, 150], ndarray [4, 5])
#   batch[1] -> (ndarray [3, 300, 150], ndarray [2, 5])
#   batch[2] -> (ndarray [3, 300, 150], ndarray [4, 5])
# The label arrays may have different row counts, so they cannot be stacked.
def my_collate_fn(batch):
    """Collate samples whose images share a shape but whose labels do not.

    Arguments:
        batch: iterable of samples; element 0 of each sample is the image
            and element 1 is its annotation array.

    Return:
        A tuple containing:
            1) (tensor) all images converted to float tensors and stacked
               along a new leading dimension
            2) (list of tensors) one float tensor of annotations per sample,
               left unstacked because their sizes may differ
    """
    converted = [
        (torch.FloatTensor(item[0]), torch.FloatTensor(item[1]))
        for item in batch
    ]
    stacked_images = torch.stack([pair[0] for pair in converted], dim=0)
    annotations = [pair[1] for pair in converted]
    return stacked_images, annotations




# Build a synthetic dataset: random H x W x 3 integer images, each paired
# with a random number of 5-column annotation rows.
img_data = []
label_data = []

nums = 34
H = 300
W = 150
for _ in range(nums):
    # `high` is exclusive in np.random.randint, so 256 is required for the
    # pixel values to cover the full 0..255 range.
    random_img = np.random.randint(low=0, high=256, size=(H, W, 3))
    # 0 to 9 annotation rows per image (upper bound exclusive).
    nums_target = np.random.randint(low=0, high=10)
    random_xyxy_label = np.random.random((nums_target, 5))
    img_data.append(random_img)
    label_data.append(random_xyxy_label)

dataset = MyDataset(img_data, label_data)
# The custom collate_fn is needed because the per-image label arrays have
# different row counts and cannot be stacked into one tensor.
dataloader = DataLoader(dataset, batch_size=3, collate_fn=my_collate_fn)

for cnt, (img, label) in enumerate(dataloader):
    print("==>>", cnt, ",  img shape=", img.shape)
    # `label` is a plain list with one tensor per image in the batch.
    for target in label:
        print("label shape=", target.shape)

打印如下:

==>> 0 ,  img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([8, 5])
label shape= torch.Size([2, 5])
label shape= torch.Size([5, 5])
==>> 1 ,  img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([3, 5])
label shape= torch.Size([8, 5])
label shape= torch.Size([5, 5])
==>> 2 ,  img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([7, 5])
label shape= torch.Size([1, 5])
label shape= torch.Size([8, 5])
posted @ 2023-11-01 14:10  无左无右  阅读(83)  评论(0)  编辑  收藏  举报