# 在线下载 CIFAR-10 数据集（训练集与测试集）
from PIL import Image
import torch
import torchvision
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

# Directory where the CIFAR-10 files live (or will be downloaded to).
_CIFAR_ROOT = '../../../dataset'

# Training split: raw samples, no input/label transforms; fetched if absent.
train_data = torchvision.datasets.CIFAR10(
    _CIFAR_ROOT,
    train=True,
    transform=None,
    target_transform=None,
    download=True,
)

# Test split: same settings as the training split.
test_data = torchvision.datasets.CIFAR10(
    _CIFAR_ROOT,
    train=False,
    transform=None,
    target_transform=None,
    download=True,
)
# 读取数据集的同时进行数据扩增
# - 方法: 使用 torchvision.transforms 组合图像变换
from PIL import Image
import torch
import torchvision
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

# Read the training set and apply data augmentation on the fly.
# Order matters: the PIL-based ops (Resize, ColorJitter, RandomRotation) run
# first, ToTensor then converts the image to a float tensor in [0, 1], and
# Normalize, which only accepts tensors, must come last.
custom_transform = transforms.Compose([
    transforms.Resize((64, 64)),            # resize to 64x64
    transforms.ColorJitter(0.2, 0.2, 0.2),  # random brightness/contrast/saturation
    transforms.RandomRotation(5),           # random rotation in [-5, 5] degrees
    transforms.ToTensor(),                  # PIL image -> float tensor in [0, 1]
    # Per-channel mean / std normalization. These are the widely used
    # ImageNet statistics, not values computed from CIFAR-10 itself.
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

# download=False assumes the files are already under the root directory
# (e.g. fetched by an earlier run with download=True).
train_data = torchvision.datasets.CIFAR10('../../../dataset',
                                          train=True,
                                          transform=custom_transform,
                                          target_transform=None,
                                          download=False)
# 使用 DataLoader 批量加载数据
from PIL import Image
import os
import torch
import torchvision
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

# Read the dataset. ToTensor is required here: the DataLoader's default
# collate function can batch tensors/numbers/arrays but not PIL images, so
# with transform=None iterating train_loader raises a TypeError.
train_data = torchvision.datasets.CIFAR10('../../../dataset', train=True,
                                          transform=transforms.ToTensor(),
                                          target_transform=None,
                                          download=True)

# Batched reading.
# num_workers >= 1 loads batches in worker subprocesses.
# On Windows, workers are started with the "spawn" method, which re-imports
# this script; without an `if __name__ == "__main__":` guard around the
# training loop that fails, so fall back to in-process loading there.
num_workers = 0 if os.name == 'nt' else 4
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=2,
                                           shuffle=True,
                                           num_workers=num_workers)