数据集
ImageNet
ImageNet 是一个大规模的视觉数据库,广泛用于计算机视觉研究领域。它由斯坦福大学的李飞飞教授及其团队创建和维护。
官网需使用教育邮箱注册才能下载数据集。
可从 Kaggle ImageNet Object Localization Challenge 下载:
nohup kaggle competitions download -c imagenet-object-localization-challenge &
AutoDL 提供了 ImageNet100 和 ImageNet 数据集:
/root/autodl-pub/ImageNet100
/root/autodl-pub/ImageNet
CIFAR-10
CIFAR-10 and CIFAR-100 datasets
AutoDL 提供了 CIFAR-10 和 CIFAR-100 的数据集:
/root/autodl-pub/cifar-100
/root/autodl-pub/cifar-10
MNIST
MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
VOC
The PASCAL Visual Object Classes Homepage
下载方法:在主页中找到竞赛列表 The VOC20xx Challenge
,进去在 Development Kit
节找到 training/validation data
并下载。
AutoDL 提供了 VOC2012 和 VOC2007 的数据集:
- VOC2012:
/root/autodl-pub/VOCdevkit/VOC2012.tar.gz
- VOCC2007:
/root/autodl-pub/VOCdevkit/VOC2007.tar.gz
使用 Torchvision 自带数据集
Torchvision 已经预先支持了一些数据集:Datasets — Torchvision main documentation
CIFAR-10
# 读取训练集
train_data = torchvision.datasets.CIFAR10('./dataset', train=True, transform=None, target_transform=None, download=True)
# 读取测试集
test_data = torchvision.datasets.CIFAR10('./dataset', train=False, transform=None, target_transform=None, download=True)
dataset_dir
:存放数据集的路径。train
(可选):如果为True
,则构建训练集,否则构建测试集。transform
:定义数据预处理,数据增强方案都是在这里指定。target_transform
:标注的预处理,分类任务不常用。download
:是否下载,若为True
则从互联网下载,如果已经在dataset_dir
下存在,就不会再次下载
数据增强:在 transform
中指定参数
custom_transform = transforms.transforms.Compose([
transforms.Resize((64, 64)), # 缩放到指定大小(64*64)
transforms.ColorJitter(0.2, 0.2, 0.2), # 随机颜色变换
transforms.RandomRotation(5), # 随机旋转
transforms.Normalize([0.485,0.456,0.406], # 对图像像素进行归一化
[0.229,0.224,0.225])])
train_data = torchvision.datasets.CIFAR10('./dataset', train=True, transform=custom_transforms, target_transform=None, download=False)
使用 DataLoader:
# 实现数据批量读取
train_loader = torch.utils.data.DataLoader(train_data, batch_size=2, shuffle=True, num_workers=4)
batch_size
:设置批次大小shuffle
:在装载过程中随机乱序num_workers
:>=1
表示多进程读取数据,在 Windows 下num_workers
只能设置为0
,否则会报错。
MNIST
# 训练集
train_set = mnist.MNIST('./data', train=True, transform=transforms.ToTensor(), download=True)
# 测试集
test_set = mnist.MNIST('./data', train=False, transform=transforms.ToTensor(), download=True)
# 训练集载入器
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
# 测试集载入器
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 可视化数据
import random
for i in range(4):
ax = plt.subplot(2, 2, i+1)
idx = random.randint(0, len(train_set))
digit_0 = train_set[idx][0].numpy()
digit_0_image = digit_0.reshape(28, 28)
ax.imshow(digit_0_image, interpolation="nearest")
ax.set_title('label: {}'.format(train_set[idx][1]), fontsize=10, color='black')
plt.show()