Pytorch笔记|小土堆|P14-15|torchvision数据集使用、Dataloader使用

学会看内置数据集的官方文档:https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10

示例代码

import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

#ToTensor
tensor_trans = transforms.ToTensor()

train_set = torchvision.datasets.CIFAR10(root=r'D:\ai-learning\pytorch\cifar10', train=True, transform=tensor_trans, download=True)
test_set = torchvision.datasets.CIFAR10(root=r'D:\ai-learning\pytorch\cifar10', train=False, transform=tensor_trans, download=True)

#部分可视化
writer = SummaryWriter("logs")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()

以CIFAR10数据集为例,其常用的参数
root=下载路径(如果没下载过,且设置download=True,则会下载至此路径;如果此路径中有已下载的数据,会校验)
train=True(提取此数据集中的训练集)train=False(提取此数据集中的测试集)
transforms=使用什么transforms方法(如果一个方法,可以直接transforms=torchvision.transforms.ToTensor;如果多个方法,先Compose)
download=True下载数据集至root路径,如果已有,则不再下载。建议常年True
*如果下载慢,可在help文档里查看此数据集的URL,复制至迅雷中下载,下载后把压缩文件复制至root目录中。运行代码时,会自动检测到下载好的数据集并校验、解压

看官方文档时,关注
1、数据集的内容,比如CIFAR10:10类,每类6k。训练集50k,测试集10k。大小32323
2、数据集的数据类型,比如CIFAR10就是PIL——要transform为Tensor类型
3、有哪些参数
4、看getitem返回什么内容,比如CIFAR10返回img, target,DataLoader后即为imgs, targets
—————————————————————————————————————
DataLoader
导入from torch.utils.data import DataLoader

常用参数
dataset=load什么数据集
batch_size=每次抽几张出来(默认是随机抽出)
shuffle=True(每个epoch是否洗牌)
num_works=0不开并行
drop_last=如果总数量除以batch_size除不尽,余数是否扔掉

DataLoader之后
load的数据按照batch_size打包,用for循环提取每个batch的数据:for data in test_loader
要去文档里看一下DataLoader的getitem返回哪些内容,比如CIFAR10返回的就是imgs, targets,所以imgs, targets = data
之后会把imgs送入神经网络

示例代码

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10(r'./cifar10', train=False, transform=torchvision.transforms.ToTensor(), download=True)

test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

# img, target = test_data[0]

writer = SummaryWriter("logs")

for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch: {}".format(epoch), imgs, step) # 可视化
        step = step + 1
writer.close()
posted @ 2024-08-03 11:33  xjl_ultrasound  阅读(14)  评论(0编辑  收藏  举报