pytorch 的 torchvision.datasets.ImageFolder 来自定义数据集
import torchvision class ClassificationDataset(torchvision.datasets.ImageFolder): """ YOLOv5 Classification Dataset. Arguments root: Dataset path """ def __init__(self, root): super().__init__(root=root) # 调用了 父类的 初始化函数,就拥有了以下的 self 属性 classes = self.classes # list 每个类的文件名 # ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] class_to_idx = self.class_to_idx # 字典 每个类的文件名,类别标签(数字) # {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9} samples = self.samples # list 图像路径,标签(0,1,2...) # [('/data/huyuzhen/proje...in/0/1.png', 0), ('/data/huyuzhen/proje...0/1000.png', 0),... targets = self.targets # list 类别标签 数字:0,1,2... # [0, 0, 0, 0, 0, 0, 0... path = '/data/huyuzhen/projects/datasets/mnist/train' dataset = ClassificationDataset(root=path)
自定义一个图像分类 类,mnist 数据组织为 :
mnist ├── test │ ├── 0 │ ├── 1 ... ├── train │ ├── 0 │ ├── 1 ...
ImageFolder是DatasetFolder的子类,有以下属性:
Attributes: classes (list): List of the class names sorted alphabetically. class_to_idx (dict): Dict with items (class_name, class_index). samples (list): List of (sample path, class_index) tuples targets (list): The class_index value for each image in the dataset """
使用 torchvision.datasets.ImageFolder 需要把数据集按如上组织。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 从HTTP原因短语缺失研究HTTP/2和HTTP/3的设计差异
· 三行代码完成国际化适配,妙~啊~