MindSpore易点通·精讲系列--数据集加载之ImageFolderDataset

Dive Into MindSpore – ImageFolderDataset For Dataset Load

MindSpore精讲系列–数据集加载之ImageFolderDataset

本文开发环境

  • Ubuntu 20.04
  • Python 3.8
  • MindSpore 1.7.0

本文内容摘要

  • 先看API
  • 简单示例
  • 深入探究
  • 本文总结
  • 遇到问题
  • 本文参考

1. 先看API

下面对主要参数做简单介绍:

  • dataset_dir – 数据集目录
  • num_samples – 读取的样本数,通常选用默认值即可
  • num_paraller_workers – 读取数据采用的线程数,一般为CPU线程数的1/4到1/2
  • shuffle – 是否打乱数据集,还是按顺序读取,默认为None。这里一定要注意,默认None并非是不打乱数据集,这个参数的默认值有点让人困惑。
  • extensions – 图片文件扩展名,可以为多个即list。如[“.JPEG”, “.png”],则读取文件夹相应扩展名的图片文件。if empty, read everything under the dir.
  • class_indexing – 文件夹名到label的索引映射字典
  • decode – 是否对图片数据进行解码,默认为False,即不解码
  • num_shards – 分布式场景下使用,可以认为是GPU或NPU的卡数
  • shard_id – 同上面参数在分布式场景下配合使用,可以认为是GPU或NPU卡的ID

2. 简单示例

本文使用的是Fruits 360数据集

2.1 解压数据

将Fruits 360数据集下载后,会得到archive.zip文件,使用unzip -x archive.zip命令进行解压。在同级目录下得到两个文件夹fruits-360_datasetfruits-360-original-size。使用命令tree -d -L 3 .对数据情况进行简单查看,输出内容如下:

.
├── fruits-360_dataset
│   └── fruits-360
│       ├── Lemon
│       ├── papers
│       ├── Test
│       ├── test-multiple_fruits
│       └── Training
└── fruits-360-original-size
    └── fruits-360-original-size
        ├── Meta
        ├── Papers
        ├── Test
        ├── Training
        └── Validation

本文将使用fruits-360_dataset文件夹。

2.2 最简用法

下面对fruits-360_dataset文件夹下的训练集fruits-360/Training进行加载。

代码如下:

参考1中参数介绍,需要将shuffle参数显示设置为False,否则无法复现。

from mindspore.dataset import ImageFolderDataset


def dataset_load(dataset_dir, shuffle=False, decode=False):
    dataset = ImageFolderDataset(
        dataset_dir=dataset_dir, shuffle=shuffle, decode=decode)

    data_size = dataset.get_dataset_size()
    print("data size: {}".format(data_size), flush=True)

    data_iter = dataset.create_dict_iterator()
    item = None
    for data in data_iter:
        item = data
        break

    # 打印数据
    print(item, flush=True)


def main():
    # 注意替换为个人路径
    train_dataset_dir = "{your_path}/fruits-360_dataset/fruits-360/Training"

    #####################
    # test decode param #
    #####################
    dataset_load(dataset_dir=train_dataset_dir, shuffle=False, decode=False)


if __name__ == "__main__":
    main()

将以上代码保存到load.py文件,使用如下命令运行:

python3 load.py

输出内容如下:

  • 数据集大小为67692,因为该文件夹下只有图片文件,也可以认为有67692个图片。
  • 数据包含两个字段:image和label。
  • image字段在decode参数为默认值False时,不对图片解码,所以可以认为是二进制数据,且其shape为一维的。
  • label字段已经进行了数值化转换。
data size: 67692
{'image': Tensor(shape=[4773], dtype=UInt8, value= [255, 216, 255, 224,   0,  16,  74,  70,  73,  70,   0,   1,   1,   0,   0,   1,   0,   1,   0,   0, 255, 219,   0,  67,
   0,   2,   1,   1,   1,   1,   1,   2,   1,   1,   1,   2,   2,   2,   2,   2,   4,   3,   2,   2,   2,   2,   5,   4,
   4,   3,   4,   6,   5,   6,   6,   6,   5,   6,   6,   6,   7,   9,   8,   6,   7,   9,   7,   6,   6,   8,  11,   8,
......
 251,  94, 126, 219, 218,  84,  16, 178,  91, 197, 168, 248,  91, 193, 130,  70, 243, 163, 144, 177, 104, 229, 186, 224,
 121, 120,   1,  92,  34, 146,  78, 229, 201,  92,  21, 175, 220, 146, 112,  51,  65,  32, 117,  52, 112,  69, 117,  66,
  10,  10, 200, 241, 234, 213, 157, 105, 243,  72,  40, 162, 138, 178,   2, 138,  40, 160,   2, 138,  40, 160,   2, 138,
  40, 160,   2, 138,  40, 160,   2, 138,  40, 160,   2, 138,  40, 160,   2, 138,  40, 160,  15, 255, 217]), 'label': Tensor(shape=[], dtype=Int32, value= 0)}

2.3 是否解码

下面将decode参数设置为True,来看看数据情况。

将如下代码

dataset_load(dataset_dir=train_dataset_dir, shuffle=False, decode=False)

修改为

dataset_load(dataset_dir=train_dataset_dir, shuffle=False, decode=True)

使用如下命令,重新运行load.py文件。

python3 load.py

输出内容如下:

  • 数据集大小同2.2一致。
  • 数据包含两个字段:image和label。
  • 因为decode参数设置为True,已经对图片进行了解码,可以看到image字段的数据维度和数值已经有了变化。
  • label字段同2.2。
data size: 67692
{'image': Tensor(shape=[100, 100, 3], dtype=UInt8, value=
[[[254, 255, 255],
  [254, 255, 255],
  [254, 255, 255],
  ...
  [255, 255, 255],
  [255, 255, 255],
  [255, 255, 255]]]), 'label': Tensor(shape=[], dtype=Int32, value= 0)}

3. 深入探究

在深入探究部分,本文来详细研究一下class_indexing参数,看看这个参数有什么意义。

首先本文提出一种异常情况,即训练集内的某个类别文件夹,在验证集/测试集不存在(可能因为数据极度不平衡或人为错误)。那么数据的标签id还能否对应好。

3.1 正常测试集

针对测试集,我们先做一次label统计。

代码如下:

import json

from mindspore.dataset import ImageFolderDataset


def label_check(dataset_dir, shuffle=False, decode=False, class_indexing=None):
    dataset = ImageFolderDataset(
        dataset_dir=dataset_dir, shuffle=shuffle, decode=decode, class_indexing=class_indexing)

    data_size = dataset.get_dataset_size()
    print("data size: {}".format(data_size), flush=True)

    data_iter = dataset.create_dict_iterator()

    label_dict = {}
    for data in data_iter:
        label_id = data["label"].asnumpy().tolist()
        label_dict[label_id] = label_dict.get(label_id, 0) + 1

    # 打印数据
    print("====== label dict ======\n{}".format(label_dict), flush=True)


def main():
    # 注意替换为个人路径
    test_dataset_dir = "{your_path}/fruits-360_dataset/fruits-360/Test"
    label_check(dataset_dir=test_dataset_dir, shuffle=False, decode=False, class_indexing=None)


if __name__ == "__main__":
    main()

将上述代码保存到check.py文件,运行命令:

python3 check.py

输出内容如下:

  • 数据集大小为22688
  • 总共标签id为131
data size: 22688
====== label dict ======
{0: 164, 1: 148, 2: 160, 3: 164, 4: 161, 5: 164, 6: 152, 7: 164, 8: 164, 9: 144, 10: 166, 11: 164, 12: 219, 13: 164, 14: 143, 15: 166, 16: 166, 17: 152, 18: 166, 19: 150, 20: 154, 21: 166, 22: 164, 23: 164, 24: 166, 25: 234, 26: 164, 27: 246, 28: 246, 29: 164, 30: 164, 31: 164, 32: 153, 33: 166, 34: 166, 35: 150, 36: 154, 37: 130, 38: 156, 39: 166, 40: 156, 41: 234, 42: 99, 43: 166, 44: 328, 45: 164, 46: 166, 47: 166, 48: 164, 49: 158, 50: 166, 51: 164, 52: 166, 53: 157, 54: 166, 55: 166, 56: 156, 57: 157, 58: 166, 59: 164, 60: 166, 61: 166, 62: 166, 63: 166, 64: 166, 65: 142, 66: 102, 67: 166, 68: 246, 69: 164, 70: 164, 71: 160, 72: 218, 73: 178, 74: 150, 75: 155, 76: 146, 77: 160, 78: 164, 79: 166, 80: 164, 81: 246, 82: 164, 83: 164, 84: 232, 85: 166, 86: 234, 87: 102, 88: 166, 89: 222, 90: 237, 91: 166, 92: 166, 93: 148, 94: 234, 95: 222, 96: 222, 97: 164, 98: 164, 99: 166, 100: 163, 101: 166, 102: 151, 103: 142, 104: 304, 105: 164, 106: 153, 107: 150, 108: 151, 109: 150, 110: 150, 111: 166, 112: 164, 113: 166, 114: 164, 115: 162, 116: 164, 117: 246, 118: 166, 119: 166, 120: 246, 121: 225, 122: 246, 123: 160, 124: 164, 125: 228, 126: 127, 127: 153, 128: 158, 129: 249, 130: 157}

3.2 异常测试集

为了进行测试,人为制造一些异常,将Test文件夹下的Lemon数据文件夹移动到上层目录。

命令如下:

cd {your_path}/fruits-360_dataset/fruits-360/Test
mv Lemon ../

3.2.1 未指定class_indexing

再次运行3.1中的check.py文件,输出内容如下:

  • 数据大小为22524
  • 总共标签id为130
data size: 22524
====== label dict ======
{0: 164, 1: 148, 2: 160, 3: 164, 4: 161, 5: 164, 6: 152, 7: 164, 8: 164, 9: 144, 10: 166, 11: 164, 12: 219, 13: 164, 14: 143, 15: 166, 16: 166, 17: 152, 18: 166, 19: 150, 20: 154, 21: 166, 22: 164, 23: 164, 24: 166, 25: 234, 26: 164, 27: 246, 28: 246, 29: 164, 30: 164, 31: 164, 32: 153, 33: 166, 34: 166, 35: 150, 36: 154, 37: 130, 38: 156, 39: 166, 40: 156, 41: 234, 42: 99, 43: 166, 44: 328, 45: 164, 46: 166, 47: 166, 48: 164, 49: 158, 50: 166, 51: 164, 52: 166, 53: 157, 54: 166, 55: 166, 56: 156, 57: 157, 58: 166, 59: 166, 60: 166, 61: 166, 62: 166, 63: 166, 64: 142, 65: 102, 66: 166, 67: 246, 68: 164, 69: 164, 70: 160, 71: 218, 72: 178, 73: 150, 74: 155, 75: 146, 76: 160, 77: 164, 78: 166, 79: 164, 80: 246, 81: 164, 82: 164, 83: 232, 84: 166, 85: 234, 86: 102, 87: 166, 88: 222, 89: 237, 90: 166, 91: 166, 92: 148, 93: 234, 94: 222, 95: 222, 96: 164, 97: 164, 98: 166, 99: 163, 100: 166, 101: 151, 102: 142, 103: 304, 104: 164, 105: 153, 106: 150, 107: 151, 108: 150, 109: 150, 110: 166, 111: 164, 112: 166, 113: 164, 114: 162, 115: 164, 116: 246, 117: 166, 118: 166, 119: 246, 120: 225, 121: 246, 122: 160, 123: 164, 124: 228, 125: 127, 126: 153, 127: 158, 128: 249, 129: 157}

**解读:**仔细观察,可以看出3.2.1中的数据标签id已经同3.1中不同,也就是说如果我们是在训练后进行测试,那么标签id已经出错,测试结果肯定相当糟糕。

3.2.2 指定class_indexing

备注:这里我们默认训练数据集也使用了class_indexing字典文件进行数据加载,或者加载的标签ID与我们后期生成的一致。

为了能够与训练集的标签id保持一致,我们先利用训练集来生成class_indexing字典文件。

生成代码如下:

import json
import os


def make_class_indexing_file(dataset_dir, class_indexing_file):
    class_names = []
    for dir_or_file in os.listdir(dataset_dir):
        if os.path.isfile(dir_or_file):
            continue
        class_names.append(dir_or_file)

    sorted_class_names = sorted(class_names)
    print("num_classes: {}\n{}".format(len(sorted_class_names), "\n".join(sorted_class_names)), flush=True)

    class_indexing_dict = dict(zip(sorted_class_names, list(range(len(sorted_class_names)))))
    print("class_indexing dict: {}".format(class_indexing_dict), flush=True)

    with open(class_indexing_file, "w", encoding="UTF8") as fp:
        json.dump(class_indexing_dict, fp, indent=4, separators=(",", ": "))


def main():
    train_dataset_dir = "{your_path}/Fruits_360/fruits-360_dataset/fruits-360/Training"
    class_indexing_file = "{your_path}/Fruits_360/fruits-360_dataset/class_indexing.json"
    make_class_indexing_file(dataset_dir=dataset_dir, class_indexing_file=class_indexing_file)


if __name__ == "__main__":
    main()

保存代码到make_class_indexing.py文件,运行命令:

python3 make_class_indexing.py

备注:生成的字典文件为{your_path}/Fruits_360/fruits-360_dataset/class_indexing.json,读者可自行更改路径。

有了字典文件,再次修改check.py文件,修改为:

import json

from mindspore.dataset import ImageFolderDataset


def label_check(dataset_dir, shuffle=False, decode=False, class_indexing=None):
    dataset = ImageFolderDataset(
        dataset_dir=dataset_dir, shuffle=shuffle, decode=decode, class_indexing=class_indexing)

    data_size = dataset.get_dataset_size()
    print("data size: {}".format(data_size), flush=True)

    data_iter = dataset.create_dict_iterator()

    label_dict = {}
    for data in data_iter:
        label_id = data["label"].asnumpy().tolist()
        label_dict[label_id] = label_dict.get(label_id, 0) + 1

    # 打印数据
    print("====== label dict ======\n{}".format(label_dict), flush=True)


def load_class_indexing_file(class_indexing_file):
    with open(class_indexing_file, "r", encoding="UTF8") as fp:
        class_indexing_dict = json.load(fp)
    print("====== class_indexing_dict: ======\n{}".format(class_indexing_dict), flush=True)

    return class_indexing_dict


def main():
    # 注意替换为个人路径
    test_dataset_dir = "{your_path}/fruits-360_dataset/fruits-360/Test"
    class_indexing_file = "{your_path}/fruits-360_dataset/class_indexing.json"
    class_indexing_dict = load_class_indexing_file(class_indexing_file)
    label_check(dataset_dir=test_dataset_dir, shuffle=False, decode=False, class_indexing=class_indexing_dict)


if __name__ == "__main__":
    main()

再次运行check.py文件,输出内容如下:

  • 数据大小同3.2.1中相同
  • 数据总标签id为131
  • 其中标签id为59数据为零,也就是我们上面移除的数据。
data size: 22524
====== label dict ======
{0: 164, 1: 148, 2: 160, 3: 164, 4: 161, 5: 164, 6: 152, 7: 164, 8: 164, 9: 144, 10: 166, 11: 164, 12: 219, 13: 164, 14: 143, 15: 166, 16: 166, 17: 152, 18: 166, 19: 150, 20: 154, 21: 166, 22: 164, 23: 164, 24: 166, 25: 234, 26: 164, 27: 246, 28: 246, 29: 164, 30: 164, 31: 164, 32: 153, 33: 166, 34: 166, 35: 150, 36: 154, 37: 130, 38: 156, 39: 166, 40: 156, 41: 234, 42: 99, 43: 166, 44: 328, 45: 164, 46: 166, 47: 166, 48: 164, 49: 158, 50: 166, 51: 164, 52: 166, 53: 157, 54: 166, 55: 166, 56: 156, 57: 157, 58: 166, 60: 166, 61: 166, 62: 166, 63: 166, 64: 166, 65: 142, 66: 102, 67: 166, 68: 246, 69: 164, 70: 164, 71: 160, 72: 218, 73: 178, 74: 150, 75: 155, 76: 146, 77: 160, 78: 164, 79: 166, 80: 164, 81: 246, 82: 164, 83: 164, 84: 232, 85: 166, 86: 234, 87: 102, 88: 166, 89: 222, 90: 237, 91: 166, 92: 166, 93: 148, 94: 234, 95: 222, 96: 222, 97: 164, 98: 164, 99: 166, 100: 163, 101: 166, 102: 151, 103: 142, 104: 304, 105: 164, 106: 153, 107: 150, 108: 151, 109: 150, 110: 150, 111: 166, 112: 164, 113: 166, 114: 164, 115: 162, 116: 164, 117: 246, 118: 166, 119: 166, 120: 246, 121: 225, 122: 246, 123: 160, 124: 164, 125: 228, 126: 127, 127: 153, 128: 158, 129: 249, 130: 157}

4. 本文总结

本文主要讲解了MindSpore中的ImageFolderDataset数据集接口,并对其中的两个参数decodeclass_indexing进行了深入探究。

一个小建议:

笔者建议用户在使用ImageFolderDataset进行数据集加载时,人为指定class_indexing参数。毕竟相关字典文件的生成并没有几行代码,但对于类别数不一致的预训练模型(比如ImageNet22k和1k)或测试集出现人为问题的情况,可以有更好的保留空间。

5. 遇到问题

  • shuffle参数默认为None,却是对数据集进行了打乱,有点让人费解。

6. 本文参考

本文为原创文章,版权归作者所有,未经授权不得转载!

posted @ 2022-07-15 15:01  Skytier  阅读(130)  评论(0编辑  收藏  举报