dataset模块之Cifar10Dataset类(1)

在开始之前,首先声明本篇文章参考官方文档,我基于官网的这篇文章加以自己的理解发表了这篇博客,希望大家能够更快更简单直观的体验MindSpore,如有不妥的地方欢迎大家指正。

【本文代码编译环境为MindSpore1.3.0 CPU版本】

本篇文章讲的主要是图像领域经典数据集CIFAR-10的加载,由于图像领域的经典数据集之间是有着很多相似之处的,所以我相信读者在学习完CIFAR-10的加载后,一定可以举一反三,实现其它MindSpore所支持数据集的加载。

关于CIFAR-10数据集:

CIFAR-10数据集包括10个类别的60000张32x32彩色图像,每个类别6000张图像。有50000个训练图像和10000个测试图像。这10个不同的类别代表飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。

以下是原始的CIFAR-10数据集结构。您可以将数据集文件解压缩到以下目录结构中,并由MindSpore的API(Cifar10Dataset类)读取。

.
└── cifar-10-batches-bin
     ├── data_batch_1.bin
     ├── data_batch_2.bin
     ├── data_batch_3.bin
     ├── data_batch_4.bin
     ├── data_batch_5.bin
     ├── test_batch.bin
     ├── readme.html
     └── batches.meta.txt

附下载链接:https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz(下载速度可能较慢,不要着急)

既然是采用Cifar10Dataset类来实现对CIFAR-10数据集的加载,那我们不妨来看一下Cifar10Dataset类的源码:

class Cifar10Dataset(MappableDataset):
    @check_mnist_cifar_dataset
    def __init__(self, dataset_dir, usage=None,num_samples=None,num_parallel_workers=None
                 , shuffle=None,
                 sampler=None, num_shards=None, shard_id=None, cache=None):
    super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
    self.dataset_dir = dataset_dir
    self.usage = replace_none(usage, "all")
     def parse(self, children=None):
        return cde.Cifar10Node(self.dataset_dir, self.usage, self.sampler)

大家看到这段简单的源代码一定会感到奇怪,Cifar10Dataset类的初始化参数这么多,但代码里面又没有其它方法,那么它需要这么多参数到底是干什么的?而且,上一篇文章咱们不是讲过,实例化之后的Cifar10Dataset对象具有很多的获取数据集信息的方法,方法呢?

大家再仔细看一下就会发现Cifar10Dataset类是继承自MappableDataset类的,而MappableDataset类又是继承自SourceDataset类的,最终SourceDataset类又是继承自Dataset类的。那为什么要这样写呢?归根结底的一个原因是利用了面向对象编程语言的特性,将某个类封装为实现某些具体方法。dataset模块可以加载各种各样的数据集以及不同格式的标准数据集,对于这些不同的数据集,不说他们的方法全部相似,但至少会有很多相似的部分,比如说(get_col_names,get_batch_size),这些方法是数据集类都要实现的,我们总不能在每一个数据类中都写一遍吧。因此,这里采用了类继承的机制,将某些公共的方法封装在底层的类中。这样既减少了代码量,还增强了代码的逻辑性,强壮性。

虽然看起来有点复杂,但我们也不要着急,一步一步来。

Cifar10Dataset类

函数修饰器:@check_mnist_cifar_dataset

def check_mnist_cifar_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset, Cifar10/100Dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

上述代码理解起来并不难,它就是一个围绕(ManifestDataset, Cifar10/100Dataset)数据集的参数检查的方法的包装器。这是因为(ManifestDataset, Cifar10/100Dataset)数据集类的初始化方法需要传入上诉代码中的参数,框架必须对输入的参数进行严格的检查。同时为了易用,简洁,实用,我们将这个方法包装起来成为一个函数修饰器。

初始化函数: def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):

对于初始化函数,我们需要理解的是它的各个参数的含义:

dataset_dir:数据格式:str 含义:包含CIFAR-10数据集的根目录的路径

usage:数据格式:str 含义:此数据集的用法,可以是训练(train),测试(test),全部(all)。usage可以是这三个字符串中的一个。train:读取50000个训练样本;test:读取10000个测试样本;all:读取全部60000个样本。

num_samples:数据格式:int 含义:需要包含在数据集中的图像数,默认值无(None),全部图像

num_parallel:数据格式:int 含义:读取数据的worker数,即同时有多少个进程在读取数据,一般使用默认值None,即配置中所默认的数

shuffle:数据格式:bool 含义:是否对数据进行混淆,默认值无(None)即最后数据集对象中的数据和数据集中的顺序是一样的

sampler:数据格式:实例化的一个采样器对象 含义:用于定义从数据集中选择样本的规则,默认值无(None),预期顺序不变

num_shards:数据格式:int 含义:将数据集划分为几块,默认值无(None),若指定该参数,则num_samples的含义会变为单个块中的最大图像数。

shared_id:数据格式:int 含义:默认值无(None),只要当指定了num_shards参数后,才能指定该参数

,为num_shards块的id

cache:数据格式:实例化的DatasetCache对象 含义:使用tensor缓存服务加速数据集处理,默认值无,注意在windows上是不支持缓存的

接下来,我们来看一下使用Cifar10Dataset类来加载我们下载好的Cifar10数据集。首先,先解压下载好的数据集并将其按以下的目录进行放置。

./cifar-10-batches-bin
├── readme.html
├── test
│   └── test_batch.bin
└── train
    ├── batches.meta.txt
    ├── data_batch_1.bin
    ├── data_batch_2.bin
    ├── data_batch_3.bin
    ├── data_batch_4.bin
    └── data_batch_5.bin

2 directories, 8 files

然后就是我们的代码部分:

from mindspore import dataset as ds
import matplotlib.pyplot as plt
data_path = "./cifar-10-batches-bin/train"

# 随机采样器,读取10张图片
sampler = ds.RandomSampler(num_samples=10)

# 创建字典对象,方便将其图像显示出来,以及打印相应的信息
data_source = ds.Cifar10Dataset(data_path, sampler=sampler)

data = data_source.create_dict_iterator()

item = next(data)
# 这一步分不必多讲,我们在手写数字识别初体验中已经描述过,记住MindSpore中的多维数组基本都是张量,当使用
# python的库(包括numpy和matplotlib.pyplot),要通过asnumpy()将张量转化为numpy数据类型
image = item["image"].asnumpy()
label = item["label"].asnumpy()
plt.title(item['label'].asnumpy())
print(image.shape)

plt.imshow(image)
plt.show()

for d in data:
    print("Image shape:", d['image'].shape, ", Label:", d['label'])

对于该数据集,还有一点需要记住,就是它的lable是0到9,每一个数字表示一个物品,这样可以方便进行监督学习。对应关系放在batches.meta.txt中,具体的如下所示:

0   airplane
1   automobile
2   bird
3   cat
4   deer
5   dog
6   frog
7   horse
8   ship
9   truck

运行结果若下图所示:

image.png

image.png

可以看到,数据集中每张图片的大小32 x 32 x 3,这个形状在我们训练过程中是需要变成3 x 32 x 32的,方便计算。

上面一张图片的标签值是2,对应的物种确实是一只鸟,只是看着觉得有点别扭,和我们拍照是不一样的,可能它这个图片是已经处理过的了。这个图片分类的问题是比手写数字识别更加复杂的,因此LeNet5网络不会取得很好的效果。而实际上确实也有处理它较好的一个网络,那就是是VGG16,不过这个网络的参数比LeNet5不知道大到哪去了,推荐有GPU环境或者是可以云上训练的读者,去model_zoo中找到相应的网络模型去跑一下,这样可以更深刻的了解MindSpore。

好了,本篇文章到这里就结束了,下一篇文章,我将给大家介绍Cifar10Dataset类的父类,以及它父类的父类。会了这些的话,就差不多可以理解其它的数据集类了。我们下篇文章再见。

posted @ 2021-12-30 18:17  MS小白  阅读(85)  评论(0)    收藏  举报