数据集

ImageNet

ImageNet 是一个大规模的视觉数据库,广泛用于计算机视觉研究领域。它由斯坦福大学的李飞飞教授及其团队创建和维护。

官网需使用教育邮箱注册才能下载数据集。

可从 Kaggle ImageNet Object Localization Challenge 下载:

nohup kaggle competitions download -c imagenet-object-localization-challenge &

AutoDL 提供了 ImageNet100 和 ImageNet 数据集:

  • /root/autodl-pub/ImageNet100
  • /root/autodl-pub/ImageNet

CIFAR-10

CIFAR-10 and CIFAR-100 datasets

AutoDL 提供了 CIFAR-10 和 CIFAR-100 的数据集:

  • /root/autodl-pub/cifar-100
  • /root/autodl-pub/cifar-10

MNIST

MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

VOC

The PASCAL Visual Object Classes Homepage

下载方法:在主页中找到竞赛列表 The VOC20xx Challenge,进去在 Development Kit 节找到 training/validation data 并下载。

AutoDL 提供了 VOC2012 和 VOC2007 的数据集:

  • VOC2012: /root/autodl-pub/VOCdevkit/VOC2012.tar.gz
  • VOCC2007: /root/autodl-pub/VOCdevkit/VOC2007.tar.gz

使用 Torchvision 自带数据集

Torchvision 已经预先支持了一些数据集:Datasets — Torchvision main documentation

CIFAR-10

# 读取训练集
train_data = torchvision.datasets.CIFAR10('./dataset', train=True, transform=None, target_transform=None, download=True)
# 读取测试集
test_data = torchvision.datasets.CIFAR10('./dataset', train=False, transform=None, target_transform=None, download=True)
  • dataset_dir:存放数据集的路径。
  • train(可选):如果为 True,则构建训练集,否则构建测试集。
  • transform:定义数据预处理,数据增强方案都是在这里指定。
  • target_transform:标注的预处理,分类任务不常用。
  • download:是否下载,若为 True 则从互联网下载,如果已经在 dataset_dir 下存在,就不会再次下载

数据增强:在 transform 中指定参数

custom_transform = transforms.transforms.Compose([
              transforms.Resize((64, 64)),               # 缩放到指定大小(64*64)
              transforms.ColorJitter(0.2, 0.2, 0.2),     # 随机颜色变换
              transforms.RandomRotation(5),              # 随机旋转
              transforms.Normalize([0.485,0.456,0.406],  # 对图像像素进行归一化
                                   [0.229,0.224,0.225])])
train_data = torchvision.datasets.CIFAR10('./dataset', train=True, transform=custom_transforms, target_transform=None, download=False)

使用 DataLoader:

# 实现数据批量读取
train_loader = torch.utils.data.DataLoader(train_data, batch_size=2, shuffle=True, num_workers=4)
  • batch_size:设置批次大小
  • shuffle:在装载过程中随机乱序
  • num_workers>=1 表示多进程读取数据,在 Windows 下 num_workers 只能设置为 0,否则会报错。

MNIST

# 训练集
train_set = mnist.MNIST('./data', train=True, transform=transforms.ToTensor(), download=True)
# 测试集
test_set = mnist.MNIST('./data', train=False, transform=transforms.ToTensor(), download=True)
# 训练集载入器
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
# 测试集载入器
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 可视化数据
import random
for i in range(4):
    ax = plt.subplot(2, 2, i+1)
    idx = random.randint(0, len(train_set))
    digit_0 = train_set[idx][0].numpy()
    digit_0_image = digit_0.reshape(28, 28)
    ax.imshow(digit_0_image, interpolation="nearest")
    ax.set_title('label: {}'.format(train_set[idx][1]), fontsize=10, color='black')
plt.show()

参考:数据读取与数据扩增 | 动手学 CV - PyTorch

posted @ 2024-09-20 09:56  Undefined443  阅读(5)  评论(0编辑  收藏  举报