MXNet Data Iterator

MXNet Data Iterator

本文先就DataBatch、DataDesc、DataIter三个主要用到的类进行介绍,然后引出Mxnet中常见的迭代器。最后介绍一种为通用数据格式设计的数据迭代器DataLoaderIter。

 

DataBatch

MXNet中的数据迭代器Data iterators类似于Python迭代器对象。在Python中,函数iter允许通过对可iterable对象(如Python列表)调用next()按顺序获取项。迭代器提供了一个抽象接口,用于遍历各种类型的iterable集合,而无需公开底层数据源的详细信息。
在MXNet中,数据迭代器在每次调用next()时返回一批数据作为DataBatch。数据批处理通常包含n个训练示例及其相应的标签。这里n是迭代器的批处理大小。在数据流结束时,当没有更多的数据可读取时,迭代器会引发像Python iter那样的StopIteration异常。DataBatch结构在这

看看DataBatch类以及他的方法:

class mxnet.io.DataBatch(datalabel=Nonepad=Noneindex=Nonebucket_key=Noneprovide_data=Noneprovide_label=None)[source]

参数:

  • data:一个关于NDArray的列表,每个NDArray都包含了bachsize大小的样本。a list of input data
  • label:一个关于NDArray的列表,每个NDArray都包含了一维的标签信息。a list of input labels
  • pad: 整型,可选。在批处理结束时填充的样本数。当读取的样本总数不能被批大小整除时使用。这些额外的填充样本在预测中被忽略。 
  • index:numpy数组格式,可选。该批量中样本的索引
  • bucket_key:整型,可选。The bucket key, used for bucketing module.
  • provide_data:一个关于DataDesc的列表,可选。DataDesc用于存储数据的名字,形状,类型和格式信息。第i个元素描述了data[i]的名字和形状。
  • provide_label:一个关于DataDesc的列表,可选。DataDesc用于存储数据的名字,形状,类型和格式信息。第i个元素描述了label[i]的名字和形状。

这个类就是一个批量的样本,每次data iterator调用next(),就会返回一个DataBatch,也即一个批量的样本。如果输入的数据是图像的话,这些图像的shape取决于DataDesc中的provide_data参数:

 

DataDesc 

class mxnet.io.DataDesc[source]

DataDesc用于存储数据的名字,形状,类型和格式信息。 

参数: 

  • cls(DataDesc):类自己
  • name:字符串,数据名字
  • shape:元组或整型,数据形状
  • dtype:nd.dtype 可选。数据类型
  • layout:字符串,可选。数据格式。包括 NCHW\NHWC

方法:

  • get_batch_axis(layout):获取与批处理大小相对应的维度。
  • get_list(shapes, types):从属性列表中获取DataDesc列表。

每个训练样本的名称、形状、类型和布局等信息及其相应的标签可以通过DataBatch中的provide_data和provide_label属性作为DataDesc数据描述符对象提供。这里定义了DataDesc的结构。 

 

DataIter

class mxnet.io.DataIter(batch_size=0)[source]

mxnet中数据迭代器dataiter的基类。mxnet中所有的数据IO都由该类的子类来处理。mxnet中的dataiter迭代器是和python中的iterators很像,每次调用nxet都会返回一个Databatch代表了一个批量中的数据。

参数:

  • batch_size:批量大小。

方法:

  • getdata():获取当前批次的数据。
  • getindex():获取当前批的索引。
  • getlabel():获取当前批次的标签。
  • getpad():获取当前批处理中的填充样本数。
  • iter_next():移到下一批。
  • next():从迭代器获取下一个数据批。
  • reset():将迭代器重置为数据的开头。

 

Data iterators:Mxnet中所有常用的迭代器

 

 

MXNet中的所有IO都通过mx.io.DataIter以及它的子类来处理。本文将讨论MXNet提供的一些常用迭代器。

import mxnet as mx
%matplotlib inline
import os
import sys
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import tarfile

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

 

Reading data in memory

import numpy as np

# fix the seed
np.random.seed(42)
mx.random.seed(42)

data = np.random.rand(100,3)
label = np.random.randint(0, 10, (100,))
data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=30)
for batch in data_iter:
    print([batch.data, batch.label, batch.pad])

 

Reading data from CSV files

#lets save `data` into a csv file first and try reading it back
np.savetxt('data.csv', data, delimiter=',')
data_iter = mx.io.CSVIter(data_csv='data.csv', data_shape=(3,), batch_size=30)
for batch in data_iter:
    print([batch.data, batch.pad])

 

Custom Iterator

当所有内置的迭代器不能满足时,可以定制。

mxnet中的迭代器应当满足:

  • 1.如果是py2应实现nxet(),py3的话应实现__next()__,并返回一个DataBatcj或升起一个StopIteration意外当迭代到最后的时候。
  • 2. 实现reset()方法来返回到迭代器头部
  • 3. 实现provide_data属性
  • 4. 实现provide_label属性

创建新迭代器时,可以从头开始定义迭代器,也可以重用现有迭代器之一。例如,在图像caption应用程序中,输入示例是图像,而标签是句子。因此,我们可以通过以下方法创建新的迭代器:

  • 使用ImageRecordIter创建一个image_iter,它提供多线程预取和扩充。
  • 使用rnn包中提供的NDArrayIter或bucketing迭代器创建caption_iter。
  • next()返回image_iter.next()和caption_iter.next()

一个实例:

 1 class SimpleIter(mx.io.DataIter):
 2     def __init__(self, data_names, data_shapes, data_gen,
 3                  label_names, label_shapes, label_gen, num_batches=10):
 4         self._provide_data = list(zip(data_names, data_shapes))
 5         self._provide_label = list(zip(label_names, label_shapes))
 6         self.num_batches = num_batches
 7         self.data_gen = data_gen
 8         self.label_gen = label_gen
 9         self.cur_batch = 0
10 
11     def __iter__(self):
12         return self
13 
14     def reset(self):
15         self.cur_batch = 0
16 
17     def __next__(self):
18         return self.next()
19 
20     @property
21     def provide_data(self):
22         return self._provide_data
23 
24     @property
25     def provide_label(self):
26         return self._provide_label
27 
28     def next(self):
29         if self.cur_batch < self.num_batches:
30             self.cur_batch += 1
31             data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
32             label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
33             return mx.io.DataBatch(data, label)
34         else:
35             raise StopIteration

构建一个mlp:

import mxnet as mx
num_classes = 10
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=num_classes)
net = mx.sym.SoftmaxOutput(data=net, name='softmax')
print(net.list_arguments())
print(net.list_outputs())

通过mxnet的module模块来喂入数据。

import logging
logging.basicConfig(level=logging.INFO)

n = 32
data_iter = SimpleIter(['data'], [(n, 100)],
                  [lambda s: np.random.uniform(-1, 1, s)],
                  ['softmax_label'], [(n,)],
                  [lambda s: np.random.randint(0, num_classes, s)])

mod = mx.mod.Module(symbol=net)
mod.fit(data_iter, num_epoch=5)

因为data_iter是迭代器类型,所以可以有get_data()、get_label()、get_index()、next()等方法。

同样因为data_iter.next()返回的是一个DataBatch类型,所以可以有data_iter.next().data、data_iter.next().label等属性。

 

 

DataLoaderIter

class mxnet.contrib.io.DataLoaderIter(loaderdata_name='data'label_name='softmax_label'dtype='float32')[source]

正如以上提到的,这个类的父类也是DataIter。可以看到他的输入是dataloader,其实dataloader已经通过以下语句取到数据了:

 1 import mxnet as mx
 2 from mxnet import gluon
 3 import numpy as np
 4 
 5 transform = lambda data, label: (data.astype(np.float32)/255, label)       # transform
 6 train_imgs = gluon.data.vision.ImageFolderDataset(root='/Users/lps/MacProjects/mxnet/root',       # root路径
 7                                                   transform=transform)
 8 
 9 print(train_imgs.items)   #打印出所有图像信息: (filename, label) 对.
10 print(train_imgs.synsets)    # 列出所有类名 synsets[i] 是 label i所对应的类名
11  
12 dataloader = gluon.data.DataLoader(train_imgs, 2, shuffle=True)    # 类似pytorch,dataset需要放到dataloader里进行打包
13 iter(dataloader).__next__()     # 可以打印出一个批量的数据
14 
15 for batch in dataloader:    
16     print(batch[0].shape)   # 打印数据形状
17     print(batch[1])         # 打印标签

这里的数据是NDArray类型。这个类型可以在gluon中进行训练了。但是可以看到博客最开始提供的接口都是 iter 类型,所以可以通过这个DataLoaderIter将dataloader类型转为dataiter类型。官方api说这样做使gluon数据加载器可以在Symbol模块中使用。可以看到module的fit接口提供的train_data和eval_data类型都是DataIter型。所以通过这样转换,上面的数据既可以送到fit函数里了:

dataiter = mx.contrib.io.DataLoaderIter(dataloader)
for i in dataiter:
     print(i)
     break
  
输出:DataBatch: data shapes: [(2, 28, 28, 3)] label shapes: [(2,)]
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
mod.fit(train_data=dataiter, eval_data=..., optimizer='sgd'...)   # 送到fit里

也就是说symbol必须接受dataiter类型的数据,不接受dataloader类型,而gluon应该两种类型都接受,待补充🌰。

 

 

 

 

其余内容见:mxnet 数据读取

 

posted @ 2020-06-07 10:22  三年一梦  阅读(532)  评论(0编辑  收藏  举报