pytorch 的 torchvision.datasets.ImageFolder 来自定义数据集

import torchvision
class ClassificationDataset(torchvision.datasets.ImageFolder):
"""
YOLOv5 Classification Dataset.
Arguments
root: Dataset path
"""
def __init__(self, root):
super().__init__(root=root) # 调用了 父类的 初始化函数,就拥有了以下的 self 属性
classes = self.classes # list 每个类的文件名
# ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
class_to_idx = self.class_to_idx # 字典 每个类的文件名,类别标签(数字)
# {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9}
samples = self.samples # list 图像路径,标签(0,1,2...)
# [('/data/huyuzhen/proje...in/0/1.png', 0), ('/data/huyuzhen/proje...0/1000.png', 0),...
targets = self.targets # list 类别标签 数字:0,1,2...
# [0, 0, 0, 0, 0, 0, 0...
path = '/data/huyuzhen/projects/datasets/mnist/train'
dataset = ClassificationDataset(root=path)

自定义一个图像分类 类,mnist 数据组织为 :

mnist
├── test
│ ├── 0
│ ├── 1
...
├── train
│ ├── 0
│ ├── 1
...

ImageFolder是DatasetFolder的子类,有以下属性:

Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""

使用 torchvision.datasets.ImageFolder 需要把数据集按如上组织。

posted @   Zenith_Hugh  阅读(153)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 从HTTP原因短语缺失研究HTTP/2和HTTP/3的设计差异
· 三行代码完成国际化适配,妙~啊~
点击右上角即可分享
微信分享提示

喜欢请打赏

扫描二维码打赏

微信打赏