Clothing1M 数据集使用

简介

Clothing1M 包含 14 个类别的 100 万张服装图像。这是一个带有噪声标签的数据集,因为数据是从多个在线购物网站收集的,并且包含许多错误标记的样本。该数据集还分别包含 50k、14k 和 10k 张带有干净标签的图像,用于训练、验证和测试。

下载地址:https://github.com/Newbeeer/L_DMI/issues/8

Dataset & DataLoader

数据集目录结构 ~/data/clothing1m:

├─images
│   ├─0
│   │  ├─00
│   │  ├─...
│   │  └─99
│   ├─...
│   └─9
│       ├─00
│       ├─...
│       └─99
├─category_names_chn.txt
├─category_names_eng.txt
└─...
import os

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms

# mode=0: noisy train set, mode=1: clean val set, mode=2: clean test set
class Clothing1m(Dataset):
    """Clothing1M dataset (14 clothing classes with noisy web labels).

    mode=0 loads the noisy 1M-image train split (labels from
    ``noisy_label_kv.txt``); mode=1 and mode=2 load the clean validation
    and test splits (labels from ``clean_label_kv.txt``).
    """

    # Number of clothing categories in Clothing1M.
    nb_classes = 14

    def __init__(self, mode=0, root='~/data/clothing1m', transform=None):
        """
        Args:
            mode: 0 = noisy train, 1 = clean val, 2 = clean test.
            root: dataset root directory containing the key-list and
                label txt files (``~`` is expanded).
            transform: optional callable applied to each PIL image.

        Raises:
            ValueError: if ``mode`` is not 0, 1 or 2.
        """
        key_list_files = ['noisy_train_key_list.txt',
                          'clean_val_key_list.txt',
                          'clean_test_key_list.txt']
        # Validate up front so a bad mode raises ValueError instead of a
        # confusing FileNotFoundError from the label-file open below.
        if mode not in (0, 1, 2):
            raise ValueError('mode should be 0, 1 or 2')

        root = os.path.expanduser(root)
        self.mode = mode
        self.root = root
        self.transform = transform

        # Label map: relative image path -> class index.
        label_file = 'noisy_label_kv.txt' if mode == 0 else 'clean_label_kv.txt'
        with open(os.path.join(root, label_file), 'r') as f:
            lines = f.read().splitlines()
        # Split each line once (the original split every line twice).
        self.labels = {}
        for line in lines:
            parts = line.split()
            self.labels[parts[0]] = int(parts[1])

        # Image paths for the requested split, plus precomputed targets.
        with open(os.path.join(root, key_list_files[mode]), 'r') as f:
            self.data = f.read().splitlines()
        self.targets = [self.labels[p] for p in self.data]

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        img_path = self.data[index]
        # Use the precomputed target list instead of a second dict lookup.
        target = self.targets[index]
        image = Image.open(os.path.join(self.root, img_path)).convert('RGB')
        # Bug fix: the original called self.transform unconditionally and
        # crashed when transform=None (the constructor default).
        if self.transform is not None:
            image = self.transform(image)
        # Noisy train split also returns the sample index, which
        # noisy-label methods use to track per-sample statistics.
        if self.mode == 0:
            return image, target, index
        return image, target


class Clothing1mDataloader:
    """Factory that builds train/val/test DataLoaders for Clothing1M."""

    def __init__(self, batch_size=64, num_workers=8, root='~/data/clothing1m'):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.root = root

        # Per-channel normalization statistics used by both pipelines.
        mean = (0.6959, 0.6537, 0.6371)
        std = (0.3113, 0.3192, 0.3214)
        # Training pipeline: random crop + horizontal flip for augmentation.
        self.transform_train = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        # Evaluation pipeline: deterministic center crop, no augmentation.
        self.transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    def _loader(self, mode, transform, shuffle):
        # Shared DataLoader construction for all three splits.
        dataset = Clothing1m(mode=mode, root=self.root, transform=transform)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=self.num_workers, pin_memory=True)

    def train(self):
        """Shuffled loader over the noisy train split (mode 0)."""
        return self._loader(0, self.transform_train, True)

    def val(self):
        """Deterministic loader over the clean validation split (mode 1)."""
        return self._loader(1, self.transform_test, False)

    def test(self):
        """Deterministic loader over the clean test split (mode 2)."""
        return self._loader(2, self.transform_test, False)

依赖

torch                              2.3.1
torchvision                        (与 torch 2.3.1 匹配的版本)
Pillow
posted @ 2024-10-01 16:40  October-  阅读(61)  评论(0编辑  收藏  举报