码头牛牛Blog点我刷新(小声)

深度学习(四)——torchvision中数据集的使用

码头牛牛·2023-07-14 00:09·402 次阅读

深度学习(四)——torchvision中数据集的使用

一、 科研数据集#

下载链接:

https://pytorch.org/vision/stable/index.html

本文中我们使用的是CIFAR数据集

二、CIFAR10数据集详解#

具体网站:

CIFAR10 — Torchvision 0.15 documentation

1. 参数详解#

  • torchvision中每个数据集的参数都是大同小异的,这里只介绍CIFAR10数据集

  • 该数据集的数据格式为PIL格式

Copy
class torchvision.datasets.CIFAR10(root:str,train:bool=True,transform:Optional[Callable]=None,target_transform:Optional[Callable]=None,download:bool=False)
  • 内置函数:

    • root(string):必须设置,输入数据集下载后存放在电脑中的路径

    • train(bool):True代表创建的一个训练集(train);False代表创建一个测试集(test)。

    • transform:对数据集中的数据进行变换

    • target_transform:对标签(target)数据进行变换

    • download(bool):True的时候会自动从网上下载这个数据集,False的时候则不会下载该数据集。

  • 代码示例:

    • 运行后直接下载数据集

    • 需要注意的是,如果下载速度过慢,则可以在运行后,把弹出的网址单拎出来,放到迅雷等软件上进行下载

Copy
import torchvision #设置训练集 #root:设置为相对路径,会在该.py文件下设置一个名为dataset的文件存放CIFAR10数据 #train: True,数据集为训练集 #download: 下载该数据集 train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True) #设置测试集;train=False test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
  • 数据标签查看:

    • 在运行上面的代码下载好数据集后,输入print(test_set[0),并使用一下pycharm的dubug功能,不难发现:

    • 也就是说,数据标签有'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'十类,分别用整数0~9来表示

    • 数据集包含的所有标签也可以用下面的代码打印出来:

Copy
print(test_set.classes) #[Run] [airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
  • 某条数据的PIL Image、标签的获取方法:img,target=test_set[索引]
Copy
img,target=test_set[0] print(img) print(target,test_set.classes[target]) #[Run] #<PIL.Image.Image image mode=RGB size=32x32 at 0x1DDF9FCD640> #3 cat
  • 显示图片:
Copy
img.show()

三、使用transform处理多组图像数据#

代码示例#

  • 首先使用Compose去定义如何处理PIL图像数据

  • 然后代入torchvision.datasets.CIFAR10中,处理里面的图像数据

Copy
#首先用Compose处理图像数据,可以先转为tensor格式,然后再裁剪等,这里只转tensor格式 import torchvision dataset_transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor() ]) #定义transform=dataset_transform,使得图像数据类型转换为Compose中处理过后的 train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True) test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
  • 对处理过后的图像进行可视化操作
Copy
from torch.utils.tensorboard import SummaryWriter writer=SummaryWriter("p10") for i in range(10): #显示test_set数据集中的前十张图片 img,target=test_set[i] writer.add_image("test_set",img,i) writer.close()
posted @   码头牛牛  阅读(402)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 【译】Visual Studio 中新的强大生产力特性
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构
点击右上角即可分享
微信分享提示
目录