MindSpore易点通·精讲系列--数据集加载之ImageFolderDataset
Dive Into MindSpore – ImageFolderDataset For Dataset Load
MindSpore精讲系列–数据集加载之ImageFolderDataset
本文开发环境
- Ubuntu 20.04
- Python 3.8
- MindSpore 1.7.0
本文内容摘要
- 先看API
- 简单示例
- 深入探究
- 本文总结
- 遇到问题
- 本文参考
1. 先看API
下面对主要参数做简单介绍:
- dataset_dir – 数据集目录
- num_samples – 读取的样本数,通常选用默认值即可
- num_paraller_workers – 读取数据采用的线程数,一般为CPU线程数的1/4到1/2
- shuffle – 是否打乱数据集,还是按顺序读取,默认为None。这里一定要注意,默认None并非是不打乱数据集,这个参数的默认值有点让人困惑。
- extensions – 图片文件扩展名,可以为多个即list。如[“.JPEG”, “.png”],则读取文件夹相应扩展名的图片文件。if empty, read everything under the dir.
- class_indexing – 文件夹名到label的索引映射字典
- decode – 是否对图片数据进行解码,默认为False,即不解码
- num_shards – 分布式场景下使用,可以认为是GPU或NPU的卡数
- shard_id – 同上面参数在分布式场景下配合使用,可以认为是GPU或NPU卡的ID
2. 简单示例
本文使用的是Fruits 360数据集
- Kaggle 下载地址
- 启智平台 下载地址) – 对于无法访问
kaggle
的读者,可以采用启智平台。
2.1 解压数据
将Fruits 360数据集下载后,会得到archive.zip文件,使用unzip -x archive.zip
命令进行解压。在同级目录下得到两个文件夹fruits-360_dataset
和fruits-360-original-size
。使用命令tree -d -L 3 .
对数据情况进行简单查看,输出内容如下:
.
├── fruits-360_dataset
│ └── fruits-360
│ ├── Lemon
│ ├── papers
│ ├── Test
│ ├── test-multiple_fruits
│ └── Training
└── fruits-360-original-size
└── fruits-360-original-size
├── Meta
├── Papers
├── Test
├── Training
└── Validation
本文将使用fruits-360_dataset
文件夹。
2.2 最简用法
下面对fruits-360_dataset
文件夹下的训练集fruits-360/Training
进行加载。
代码如下:
参考1中参数介绍,需要将shuffle参数显示设置为False,否则无法复现。
from mindspore.dataset import ImageFolderDataset
def dataset_load(dataset_dir, shuffle=False, decode=False):
dataset = ImageFolderDataset(
dataset_dir=dataset_dir, shuffle=shuffle, decode=decode)
data_size = dataset.get_dataset_size()
print("data size: {}".format(data_size), flush=True)
data_iter = dataset.create_dict_iterator()
item = None
for data in data_iter:
item = data
break
# 打印数据
print(item, flush=True)
def main():
# 注意替换为个人路径
train_dataset_dir = "{your_path}/fruits-360_dataset/fruits-360/Training"
#####################
# test decode param #
#####################
dataset_load(dataset_dir=train_dataset_dir, shuffle=False, decode=False)
if __name__ == "__main__":
main()
将以上代码保存到load.py
文件,使用如下命令运行:
python3 load.py
输出内容如下:
- 数据集大小为67692,因为该文件夹下只有图片文件,也可以认为有67692个图片。
- 数据包含两个字段:image和label。
- image字段在decode参数为默认值False时,不对图片解码,所以可以认为是二进制数据,且其shape为一维的。
- label字段已经进行了数值化转换。
data size: 67692 {'image': Tensor(shape=[4773], dtype=UInt8, value= [255, 216, 255, 224, 0, 16, 74, 70, 73, 70, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 255, 219, 0, 67, 0, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 2, 2, 2, 4, 3, 2, 2, 2, 2, 5, 4, 4, 3, 4, 6, 5, 6, 6, 6, 5, 6, 6, 6, 7, 9, 8, 6, 7, 9, 7, 6, 6, 8, 11, 8, ...... 251, 94, 126, 219, 218, 84, 16, 178, 91, 197, 168, 248, 91, 193, 130, 70, 243, 163, 144, 177, 104, 229, 186, 224, 121, 120, 1, 92, 34, 146, 78, 229, 201, 92, 21, 175, 220, 146, 112, 51, 65, 32, 117, 52, 112, 69, 117, 66, 10, 10, 200, 241, 234, 213, 157, 105, 243, 72, 40, 162, 138, 178, 2, 138, 40, 160, 2, 138, 40, 160, 2, 138, 40, 160, 2, 138, 40, 160, 2, 138, 40, 160, 2, 138, 40, 160, 2, 138, 40, 160, 15, 255, 217]), 'label': Tensor(shape=[], dtype=Int32, value= 0)}
2.3 是否解码
下面将decode参数设置为True,来看看数据情况。
将如下代码
dataset_load(dataset_dir=train_dataset_dir, shuffle=False, decode=False)
修改为
dataset_load(dataset_dir=train_dataset_dir, shuffle=False, decode=True)
使用如下命令,重新运行load.py
文件。
python3 load.py
输出内容如下:
- 数据集大小同2.2一致。
- 数据包含两个字段:image和label。
- 因为decode参数设置为True,已经对图片进行了解码,可以看到image字段的数据维度和数值已经有了变化。
- label字段同2.2。
data size: 67692 {'image': Tensor(shape=[100, 100, 3], dtype=UInt8, value= [[[254, 255, 255], [254, 255, 255], [254, 255, 255], ... [255, 255, 255], [255, 255, 255], [255, 255, 255]]]), 'label': Tensor(shape=[], dtype=Int32, value= 0)}
3. 深入探究
在深入探究部分,本文来详细研究一下class_indexing参数,看看这个参数有什么意义。
首先本文提出一种异常情况,即训练集内的某个类别文件夹,在验证集/测试集不存在(可能因为数据极度不平衡或人为错误)。那么数据的标签id还能否对应好。
3.1 正常测试集
针对测试集,我们先做一次label统计。
代码如下:
import json
from mindspore.dataset import ImageFolderDataset
def label_check(dataset_dir, shuffle=False, decode=False, class_indexing=None):
dataset = ImageFolderDataset(
dataset_dir=dataset_dir, shuffle=shuffle, decode=decode, class_indexing=class_indexing)
data_size = dataset.get_dataset_size()
print("data size: {}".format(data_size), flush=True)
data_iter = dataset.create_dict_iterator()
label_dict = {}
for data in data_iter:
label_id = data["label"].asnumpy().tolist()
label_dict[label_id] = label_dict.get(label_id, 0) + 1
# 打印数据
print("====== label dict ======\n{}".format(label_dict), flush=True)
def main():
# 注意替换为个人路径
test_dataset_dir = "{your_path}/fruits-360_dataset/fruits-360/Test"
label_check(dataset_dir=test_dataset_dir, shuffle=False, decode=False, class_indexing=None)
if __name__ == "__main__":
main()
将上述代码保存到check.py
文件,运行命令:
python3 check.py
输出内容如下:
- 数据集大小为22688
- 总共标签id为131
data size: 22688
====== label dict ======
{0: 164, 1: 148, 2: 160, 3: 164, 4: 161, 5: 164, 6: 152, 7: 164, 8: 164, 9: 144, 10: 166, 11: 164, 12: 219, 13: 164, 14: 143, 15: 166, 16: 166, 17: 152, 18: 166, 19: 150, 20: 154, 21: 166, 22: 164, 23: 164, 24: 166, 25: 234, 26: 164, 27: 246, 28: 246, 29: 164, 30: 164, 31: 164, 32: 153, 33: 166, 34: 166, 35: 150, 36: 154, 37: 130, 38: 156, 39: 166, 40: 156, 41: 234, 42: 99, 43: 166, 44: 328, 45: 164, 46: 166, 47: 166, 48: 164, 49: 158, 50: 166, 51: 164, 52: 166, 53: 157, 54: 166, 55: 166, 56: 156, 57: 157, 58: 166, 59: 164, 60: 166, 61: 166, 62: 166, 63: 166, 64: 166, 65: 142, 66: 102, 67: 166, 68: 246, 69: 164, 70: 164, 71: 160, 72: 218, 73: 178, 74: 150, 75: 155, 76: 146, 77: 160, 78: 164, 79: 166, 80: 164, 81: 246, 82: 164, 83: 164, 84: 232, 85: 166, 86: 234, 87: 102, 88: 166, 89: 222, 90: 237, 91: 166, 92: 166, 93: 148, 94: 234, 95: 222, 96: 222, 97: 164, 98: 164, 99: 166, 100: 163, 101: 166, 102: 151, 103: 142, 104: 304, 105: 164, 106: 153, 107: 150, 108: 151, 109: 150, 110: 150, 111: 166, 112: 164, 113: 166, 114: 164, 115: 162, 116: 164, 117: 246,