clothing1m数据集使用
简介
Clothing1M 包含 14 个类别的 100 万张服装图像。这是一个带有噪声标签的数据集,因为数据是从多个在线购物网站收集的,并且包含许多错误标记的样本。该数据集还分别包含 50k、14k 和 10k 张带有干净标签的图像,用于训练、验证和测试。
下载地址:https://github.com/Newbeeer/L_DMI/issues/8
Dataset & DataLoader
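数据集目录结构(根目录 ~/data/clothing1m):
~/data/clothing1m
├─images
│  ├─0
│  │  ├─00
│  │  ├─...
│  │  └─99
│  ├─...
│  └─9
│     ├─00
│     ├─...
│     └─99
├─category_names_chn.txt
├─category_names_eng.txt
├─noisy_label_kv.txt
├─noisy_train_key_list.txt
├─clean_label_kv.txt
├─clean_val_key_list.txt
├─clean_test_key_list.txt
└─...
其中 *_key_list.txt 的每一行是相对于根目录的图像路径;*_label_kv.txt 的每一行为"图像路径 整数标签"。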
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
class Clothing1m(Dataset):
    """Clothing1M image-classification dataset.

    The split is selected with ``mode``:
        0 -- noisy training set: keys from ``noisy_train_key_list.txt``,
             labels from ``noisy_label_kv.txt``
        1 -- clean validation set: keys from ``clean_val_key_list.txt``,
             labels from ``clean_label_kv.txt``
        2 -- clean test set: keys from ``clean_test_key_list.txt``,
             labels from ``clean_label_kv.txt``

    Every key is an image path relative to ``root``. Every label-file line
    is ``<key> <integer label>``; any extra tokens are ignored.
    """

    nb_classes = 14

    # (label file, key-list file) for each mode; the tuple index is ``mode``.
    _SPLIT_FILES = (
        ('noisy_label_kv.txt', 'noisy_train_key_list.txt'),
        ('clean_label_kv.txt', 'clean_val_key_list.txt'),
        ('clean_label_kv.txt', 'clean_test_key_list.txt'),
    )

    def __init__(self, mode=0, root='~/data/clothing1m', transform=None):
        """Load the key list and the labels for one split.

        Args:
            mode: 0 (noisy train), 1 (clean val) or 2 (clean test).
            root: dataset directory; a leading ``~`` is expanded.
            transform: optional callable applied to each RGB PIL image.
                When None, the PIL image is returned unchanged.

        Raises:
            ValueError: if ``mode`` is not 0, 1 or 2.
            KeyError: if a key in the split list has no entry in the label file.
        """
        # Validate before touching the filesystem, so a bad mode fails with a
        # clear error rather than a misleading FileNotFoundError.
        if mode not in (0, 1, 2):
            raise ValueError('mode should be 0, 1 or 2')
        root = os.path.expanduser(root)
        self.mode = mode
        self.root = root
        self.transform = transform

        label_file, key_file = self._SPLIT_FILES[mode]

        # Map image key -> integer label.
        self.labels = {}
        for line in self._read_lines(os.path.join(root, label_file)):
            parts = line.split()
            self.labels[parts[0]] = int(parts[1])

        # Ordered image keys of this split, with their labels aligned by index.
        self.data = self._read_lines(os.path.join(root, key_file))
        self.targets = [self.labels[img_path] for img_path in self.data]

    @staticmethod
    def _read_lines(path):
        """Return the whitespace-stripped, non-blank lines of a text file."""
        with open(path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f if line.strip()]

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.targets)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(image, target, index)`` for mode 0, otherwise ``(image, target)``.
            ``image`` is the RGB PIL image after ``transform``, if one was given.
        """
        img_path = self.data[index]
        target = self.targets[index]
        # ``convert`` returns a new, fully loaded image, so the source file
        # can be closed right away instead of leaking one handle per sample.
        with Image.open(os.path.join(self.root, img_path)) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        if self.mode == 0:
            return image, target, index
        return image, target
class Clothing1mDataloader:
    """Factory for the Clothing1M train / validation / test DataLoaders.

    Training images are resized to 256, randomly cropped to 224 and randomly
    flipped horizontally. Validation and test images are resized to 256 and
    center-cropped to 224. Both pipelines end with ToTensor and the same
    per-channel Normalize constants.
    """

    def __init__(self, batch_size=64, num_workers=8, root='~/data/clothing1m'):
        """Store loader settings and build the two image pipelines.

        Args:
            batch_size: samples per batch for every split.
            num_workers: number of DataLoader worker processes.
            root: dataset directory, passed through to ``Clothing1m``.
        """
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.root = root

        # Per-channel RGB normalization constants shared by every split.
        channel_mean = (0.6959, 0.6537, 0.6371)
        channel_std = (0.3113, 0.3192, 0.3214)

        self.transform_train = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        self.transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

    def _make_loader(self, mode, transform, shuffle):
        """Wrap the ``mode`` split of Clothing1m in a pinned-memory DataLoader."""
        split = Clothing1m(mode=mode, root=self.root, transform=transform)
        return DataLoader(
            split,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train(self):
        """Return a shuffled loader over the noisy training split (mode 0)."""
        return self._make_loader(0, self.transform_train, True)

    def val(self):
        """Return an unshuffled loader over the clean validation split (mode 1)."""
        return self._make_loader(1, self.transform_test, False)

    def test(self):
        """Return an unshuffled loader over the clean test split (mode 2)."""
        return self._make_loader(2, self.transform_test, False)
依赖
torch 2.3.1
torchvision(与 torch 版本匹配,torch 2.3.1 对应 torchvision 0.18.1)
Pillow