torchvision-加载数据

在线下载

from PIL import Image
import torch
import torchvision
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

# Directory that holds (or will receive) the CIFAR-10 files.
DATA_ROOT = '../../../dataset'


def load_cifar10(is_train):
    """Build a CIFAR-10 dataset object for one split.

    Args:
        is_train: True for the training split, False for the test split.

    Returns:
        A ``torchvision.datasets.CIFAR10`` instance with no image or target
        transform; ``download=True`` lets torchvision fetch the data into
        ``DATA_ROOT`` when it is not already there.
    """
    return torchvision.datasets.CIFAR10(
        DATA_ROOT,
        train=is_train,
        transform=None,
        target_transform=None,
        download=True,
    )


# Training split
train_data = load_cifar10(True)
# Test split
test_data = load_cifar10(False)

读取数据集的同时进行数据扩增

  • 方法:torchvision.transforms
from PIL import Image
import torch
import torchvision
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

# Read the training set while applying data augmentation.
# Order matters: Resize / ColorJitter / RandomRotation operate on the PIL
# image, ToTensor then converts it to a float tensor, and only after that can
# Normalize (which requires a tensor input) be applied.
custom_transforms = transforms.Compose([
    transforms.Resize((64, 64)),            # resize to 64x64
    transforms.ColorJitter(0.2, 0.2, 0.2),  # random colour jitter (brightness, contrast, saturation)
    transforms.RandomRotation(5),           # random rotation in [-5, 5] degrees
    transforms.ToTensor(),                  # PIL image -> float tensor, required before Normalize
    # Per-channel normalisation: first list is the mean, second the std.
    # NOTE(review): these look like the commonly used ImageNet statistics,
    # not CIFAR-10's own — confirm this is intended.
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])
# Keep the name originally assigned in this snippet as an alias.
custom_transform = custom_transforms

# download=False: assumes the dataset was already downloaded into this
# directory (e.g. by the previous snippet).
train_data = torchvision.datasets.CIFAR10('../../../dataset',
                                          train=True,
                                          transform=custom_transforms,
                                          target_transform=None,
                                          download=False)

DataLoader加载数据

import os

from PIL import Image
import torch
import torchvision
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

# Read the dataset. A ToTensor transform is needed here: with transform=None
# each sample is a PIL image, which DataLoader's default collate function
# cannot stack into a batch, so iterating the loader would fail.
train_data = torchvision.datasets.CIFAR10('../../../dataset', train=True,
                                          transform=transforms.ToTensor(),
                                          target_transform=None,
                                          download=True)
# Batched data loading.
# num_workers >= 1 loads data in worker subprocesses.
# On Windows, worker processes are started with "spawn", which re-imports the
# main module; num_workers > 0 then needs the loading code to sit under an
# `if __name__ == "__main__":` guard. This flat script has no such guard, so
# fall back to in-process loading (num_workers=0) on Windows.
num_workers = 0 if os.name == 'nt' else 4
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=2,
                                           shuffle=True,
                                           num_workers=num_workers)
posted @ 2022-04-06 10:11  ArdenWang  阅读(70)  评论(0)  编辑  收藏  举报