mxnet 数据读取
Mxnet的数据(图像和非图像)读取方式太太太多了。基于mxnet的、基于gluon的。 主要是Data iterators很多:
必读api:Iterators - Loading data
API给出的数据迭代器:主要是基于io和基于image的两大类:
这么多的数据读取方法,针对图像而言,Mxnet主要有四种:
Image IO
- 1. 利用mx.image.imdecode来load原始的图像数据
- 2. 利用
mx.img.ImageIter
来实现,非常灵活,易于定制(做transform),可以同时读取.rec格式和原始image格式,所以应当主要使用该方法。 - 3. 利用
mx.io.ImageRecordIter实现,基于C++后端,没那么灵活但是易于扩展到其他语言
- 4. 利用mx.io.DataIter来自己定制
方式1:
class mxnet.gluon.data.vision.datasets.
ImageFolderDataset
(root, flag=1, transform=None)[source]
基于 mxnet.gluon.data.dataset.Dataset的类。即继承自Dataset类。官网api地址
参数:
root:root文件夹路径
flag:0/1,如果是0则将输入图像转为灰度1通道,如果是1,则转为彩图3通道。
transform:可callable的,就是说类里要有call函数。默认是None。
用法:
用法非常简单,如果用过pytorch中torchvision的ImageFolder的话,这个基本一样了。
首先需要将数据做成这种格式:
root文件夹下有多个类别,每个类别是一个文件夹,里面是该类的数据。然后读取方法:
import mxnet as mx from mxnet import gluon import numpy as np transform = lambda data, label: (data.astype(np.float32)/255, label) # transform train_imgs = gluon.data.vision.ImageFolderDataset(root='/Users/lps/MacProjects/mxnet/root', # root路径 transform=transform) print(train_imgs.items) #打印出所有图像信息: (filename, label) 对. print(train_imgs.synsets) # 列出所有类名 synsets[i] 是 label i所对应的类名 data = gluon.data.DataLoader(train_imgs, 32, shuffle=True) # 类似pytorch,dataset需要放到dataloader里进行打包 iter(data).__next__() # 可以打印出一个批量的数据
方式2.
class mxnet.gluon.data.vision.datasets.
ImageListDataset
(root='.', imglist=None, flag=1)[source]
Bases: mxnet.gluon.data.dataset.Dataset。也是dataset类,所以用法和方式1基本一致。
参数:
root就是list文件的路径
imglist:
如果是一个纯txt文件,里面的格式应该是这样的:
如果是一个 .lst文件,则应该是这样的格式:
方式3.
class mxnet.io.
NDArrayIter
(data, label=None, batch_size=1, shuffle=False, last_batch_handle='pad', data_name='data', label_name='softmax_label')[source]
基于 Bases: mxnet.io.io.DataIter,即继承自DataIter类。
参数:易懂。
用法:可以看到这个类已经是个iter了,所以直接读取就行:
import mxnet as mx import numpy as np data = np.random.rand(100, 3) # 100个数据每个数据3特征 label = np.random.randint(0, 10, (100,)) # 100个标签 data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=30) # 创建了类似dataloader的东西,可以迭代 for batch in data_iter: print([batch.data, batch.label, batch.pad], '\n')
方式4.
mxnet.io.
CSVIter
(*args, **kwargs))
提供从csv文件中读取的接口,同样也是一个iter。
#lets save `data` into a csv file first and try reading it back np.savetxt('data.csv', data, delimiter=',') # data_shape对应不上会报错 data_iter = mx.io.CSVIter(data_csv='data.csv', data_shape=(3,), batch_size=30) for batch in data_iter: print([batch.data, batch.pad])
方式5.
mxnet.recordio 模块 API
RecordIO是MXNet用于数据IO的文件格式,文件后缀为.rec。它紧凑地打包数据,以便从Hadoop HDFS和AWS S3等分布式文件系统进行高效的读写。 MXNet提供MXRecordIO 和MXIndexedRecordIO,用于数据的顺序访问和随机访问。
注意,.rec文件写入的必须是整数或者二进制数据。
该模块主要包括3个类:
IRHeader HEADER的别名
MXRecordIO
(uri, flag) 读写记录数据格式,支持顺序读写。
MXIndexedRecordIO
(idx_path, uri, flag[, …]) 读写记录数据格式,支持随机存取。
二进制数据的装包(mx.recordio.pack)与拆包(mx.recordio.unpack)。 pack 和unpack 用于存储浮点数(或1维浮点数组)标签和二进制数据。
图像数据的装包与拆包,由于图片数据在DL中尤为常用,所以单独给图片数组设计出接口,这个接口可以接收numpy数组,自动将之转化为二进制数据存入文件,解压时逆向操作。MXNet提供pack_img 和unpack_img 来打包/解压图像数据,pack_img 打包的记录可以由mx.io.ImageRecordIter 加载。
该模块主要包括4个函数:
pack
(header, s) 打包一个字符串到MXImageRecord
pack_img
(header, img[, quality, img_fmt]) 打包一个image到MXImageRecord
unpack
(s) 解包一个MXImageRecord到string
unpack_img
(s[, iscolor]) 解包一个MXImageRecord到图像
1-1. 第一个类IRHeader
1-2. 第二个类 class mxnet.recordio.
MXRecordIO
(uri, flag)
继承自object,也是第三个类MXIndexedRecordIO
的父类。
参数:
uri:字符串型,指向recored文件的路径
flag:字符串型,‘w’为写,‘r’为读
方法:
close():关闭 record文件
open():打开record文件
read(): 以字符串形式返回record
reset(): 重制指向第一项的指针
write(): 将字符串缓冲区作为record插入。
用法:
利用MXRecordIO顺序写入:
import mxnet as mx import numpy as np record = mx.recordio.MXRecordIO('tmp.rec', 'w') for i in range(5): record.write(b'record_%d'%i) record.close()
这时会生成一个tmp.rec文件。读取的时候:
import mxnet as mx import numpy as np record = mx.recordio.MXRecordIO('tmp.rec', 'r') for i in range(5): item = record.read() print(item)
1-3. 第三个类 class mxnet.recordio.
MXIndexedRecordIO
(idx_path, uri, flag, key_type=<class 'int'>)
参数:
idx_path: index文件的路径
uri:字符串型,指向recored文件的路径
flag:字符串型,‘w’为写,‘r’为读
key_type: keys的数据类型
方法:
close():关闭 record文件
open():打开record文件
read_idx(): 给定idx返回record
seek(): 设置当前读取指针位置
tell():返回写入头的当前位置。
write_idx(): 在给定索引处插入record。
用法:
MXIndexedRecordIO 支持随机或索引访问数据。 我们将创建一个索引记录文件和一个相应的索引文件,如下所示:
import mxnet as mx import numpy as np record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w') for i in range(5): record.write_idx(i, b'record_%d'%i) record.close()
这时会生成两个文件:tmp.idx和tmp.rec,读取他们:
record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'r') record.read_idx(3) # 利用键来读
2-1 第一个和第三个函数:
mxnet.recordio.
pack
(header, s)[source]
mxnet.recordio.
unpack
(s)[source]
参数:
header:IRHeader类型,image record的header。header.label可以是数字或数组。参阅IRHeader中的更多详细信息。
s:字符串,要被打包的raw image string
返回 s:得到的打包好的string
用法:
# pack data = b'data' label1 = 1.0 header1 = mx.recordio.IRHeader(flag=0, label=label1, id=1, id2=0) s1 = mx.recordio.pack(header1, data) label2 = [1.0, 2.0, 3.0] header2 = mx.recordio.IRHeader(flag=3, label=label2, id=2, id2=0) s2 = mx.recordio.pack(header2, data) # unpack print(mx.recordio.unpack(s1)) print(mx.recordio.unpack(s2))
(HEADER(flag=0, label=1.0, id=1, id2=0), b'data') (HEADER(flag=3, label=array([ 1., 2., 3.], dtype=float32), id=2, id2=0), b'data')
2-2 第二个和第四个函数:
mxnet.recordio.
pack_img
(header, img, quality=95, img_fmt='.jpg')[source]
mxnet.recordio.
unpack_img
(s, iscolor=-1)[source]
用法:
data = np.ones((3,3,1), dtype=np.uint8) label = 1.0 header = mx.recordio.IRHeader(flag=0, label=label, id=0, id2=0) s = mx.recordio.pack_img(header, data, quality=100, img_fmt='.jpg') # unpack_img print(mx.recordio.unpack_img(s))
(HEADER(flag=0, label=1.0, id=0, id2=0),
array([[1, 1, 1],
[1, 1, 1],
[1, 1, 1]], dtype=uint8))
方式6 使用im2rec进行打包
MXNet框架用于做图像相关的项目时,读取图像主要有两种方式:
- 第一种是读.rec格式的文件,类似Caffe框架中LMDB,优点是.rec文件比较稳定,移植到别的电脑上也能复现,缺点是占空间(.rec文件的大小基本上和图像的存储大小差不多),而且增删数据不大灵活。需要idx搭配使用,下面脚本会一并生成。
- 第二种是.lst和图像结合的方式,首先在前面生成.rec文件的过程中也会生成.lst文件,这个.lst文件就是图像路径和标签的对应列表,也就是说通过维护这个列表来控制你训练集和测试集的变化,优点是灵活且不占空间,缺点是如果图像格式不符合要求的话容易出错而且如果列表中的某些图像路径对应的图像文件夹中图像被删除,就寻找不到,另外如果你不是从固态硬盘上读取图像的话,速度会很慢。
1.生成.lst
数据准备:
# .
# └── MacProjects
# ├── mxnet
# ├── im2rec.py
# └── images
# ├── cat
# └── dog
那么运行下面的命令就可以生成.lst文件:
python3 im2rec.py --list --recursive --train-ratio 0.9 mxrec/dog_cat_cls images
--list 说明要产生lst文件
--recursive 遍历所有子文件夹,炳辉给每个子文件夹一个编号
--train_ratio 确定训练集和测试集的比例
mxrec/dog_cat_cls 指的是文件命名前缀,存下来的文件会在mxrec文件夹:两个文件:dog_cat_cls_train.lst、dog_cat_cls_val.lst
images:要遍历的文件夹名字
2.生成.rec
python3 im2rec.py mxrec/dog_cat_cls images --resize 16 --num-thread 4
得到lst文件后就可以根据该文件以及图像,生成rec和idx文件
参数同上,--num-thread 表示线程数。 --resize就是将图像resize后保存
此时得到的文件共六个,train/val 以及各自有.lst、.idx、.rec文件。
3. 数据读取。得到rec文件后,有两种方式来读取:MXNet的图像数据导入模块主要有mxnet.io.ImageRecordIter和mxnet.image.ImageIter两个类,前者主要用来读取.rec格式的数据,后者既可以读.rec格式文件,也可以读原图像数据。
3.1 mxnet.io.
ImageRecordIter
(*args, **kwargs)
参数:及其多,建议直接看api
返回类型:MXDataIter
import mxnet as mx import matplotlib.pyplot as plt import numpy as np data_iter = mx.io.ImageRecordIter( path_imgrec="/Users/bytedance/MacProjects/mxrec/dog_cat_cls_val.rec", # the target record file data_shape=(3, 16, 16), # output data shape. An 227x227 region will be cropped from the original image. batch_size=4, # number of samples per batch resize=16 # resize the shorter edge to 256 before cropping # ... you can add more augumentation options as defined in ImageRecordIter. ) data_iter.reset() # Reset the iterator to the begin of the data. batch = data_iter.next() data = batch.data[0] for i in range(4): plt.subplot(1,4,i+1) plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1,2,0))) plt.show()
data_iter.getdata() # 得到批量数据
data_iter.getindex # 得到批量的index
data_iter.getlabel() # 得到批量标签
3.2 class mxnet.image.
ImageIter
(batch_size, data_shape, label_width=1, path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None, shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None, data_name='data', label_name='softmax_label', dtype='float32', last_batch_handle='pad', **kwargs)[source]
import mxnet as mx import matplotlib.pyplot as plt import numpy as np # ImageIter 是一个灵活的界面,支持以RecordIO和Raw格式加载图像 data_iter = mx.image.ImageIter(batch_size=4, data_shape=(3, 227, 227), path_imgrec="/Users/bytedance/MacProjects/mxrec/dog_cat_cls_val.rec", # 需要rec和idx path_imgidx="/Users/bytedance/MacProjects/mxrec/dog_cat_cls_val.idx") data_iter.reset() batch = data_iter.next() data = batch.data[0] for i in range(4): plt.subplot(1, 4, i + 1) plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1, 2, 0))) plt.show()
mxnet.image.ImageIter是一个非常重要的类。在MXNet中,当你要读入图像数据时,可以用im2rec.py生成lst和rec文件,然后用mxnet.io.ImageRecordIter类来读取rec文件或者用这个mxnet.image.ImageIter类来读取rec文件,但是这个函数和前者相比还能直接读取图像文件,这样就可以不用生成占内存的rec文件了,只需要原图像文件和lst文件即可。另外,在mxnet.io.ImageRecordIter中对于数据的预处理操作都是固定的,不好修改,但是mxnet.image.ImageIter却可以非常灵活地添加各种预处理操作。
使用.lst和图像时,示意如下:
import mxnet as mx import matplotlib.pyplot as plt import numpy as np # ImageIter 是一个灵活的界面,支持以RecordIO和Raw格式加载图像 data_iter = mx.image.ImageIter( batch_size = 3, data_shape = (3,16,16), label_width = 1, path_imglist = "/Users/bytedance/MacProjects/mxrec/dog_cat_cls_val.lst", # 需要lst和images path_root = '/Users/bytedance/MacProjects/images', part_index = 0, shuffle = True, data_name = 'data', label_name = 'softmax_label', aug_list = mx.image.CreateAugmenter((3,16,16),resize=16,rand_crop=True,rand_mirror=True,mean=True)) data_iter.reset() batch = data_iter.next() data = batch.data[0] for i in range(4): plt.subplot(1, 4, i + 1) plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1, 2, 0))) plt.show()
这里的path_imglist参数和path_root参数是这个类特有的,分别表示.lst文件和图像的路
只是一个列表文件,大大节省了存储空间,也方便以后对数据的增删改变,因为只要重新生成.lst文件即可,而不需要花时间生成占空间的.rec文件。
如果aug_list这个参数没有赋值(默认是None),那么就不对图像做预处理;如果这个参数有值,那么就调用CreateAugmenter()函数生成预处理列表。
CreateAugmenter:
1 def CreateAugmenter(data_shape, resize=0, rand_crop=False, rand_resize=False, rand_mirror=False,mean=None, 2 std=None, brightness=0, contrast=0, saturation=0, 3 pca_noise=0, inter_method=2): 4 """Creates an augmenter list.""" 5 auglist = [] 6 7 8 # resize这个参数很重要,一般都要做resize,如果你的resize参数设置为224,你的原图像是350*300,那么最后resize的大小就是 9 # (350*300/224)*224。这里ResizeAug()函数调用resize_short()函数,resize_short()函数调用OpenCV的imresize()函数完成resize 10 # ,interp参数为2表示采用双三次插值做resize,可以参考:http://docs.opencv.org/master/da/d54/group__imgproc__transform.html。 11 if resize > 0: 12 auglist.append(ResizeAug(resize, inter_method)) 13 14 crop_size = (data_shape[2], data_shape[1]) 15 16 # 如果rand_resize参数是true,那么会调用RandomSizedCropAug()函数,输入是size,min_area,retio,interp, 17 # 这个函数既做resize又做crop,因此这边才会写成if elif的语句。RandomSizedCropAug()函数调用random_size_crop()函数, 18 # 这个函数会先生成随机的坐标点和长宽值,然后调用fixed_crop()函数做crop。 19 #这里还有一个语句是assert rand_crop,python的assert语句是用来声明其布尔值必须为真,如果表达式为假,就会触发异常。 20 # 也就是说要调用RandomSizedCropAug()函数的前提是rand_crop是True。 21 if rand_resize: 22 assert rand_crop 23 auglist.append(RandomSizedCropAug(crop_size, 0.3, (3.0 / 4.0, 4.0 / 3.0), inter_method)) 24 25 #如果rand_crop参数是true,表示随机裁剪,randomCropAug()函数的输入之一是crop_size, 26 # 这个crop_size就是CreateAugmenter()函数的输入data_shape的图像大小,然后randomCropAug()函数调用random_crop()函数, 27 # random_crop()函数会先生成新的长宽值和坐标点,然后以此调用fixed_crop()函数做crop, 28 # 最后返回crop后的图像和坐标即长宽值,因为生成中心坐标点的时候是随机的,所以还是random crop。 29 elif rand_crop: 30 auglist.append(RandomCropAug(crop_size, inter_method)) 31 32 # 如果前面两个if条件都不满足,就调用CenterCropAug()函数做crop,这个函数的输入也包括了crop_size,也就是你的输入data_shape, 33 # 所以这个参数是很有用的。CenterCropAug()函数调用center_crop()函数,这个函数的输入输出都是NDArray格式。 34 # center_crop()函数和random_crop()函数的区别在于前者坐标点的生成不是随机的,而是和原图像一样, 35 # 然后再将坐标点和新的长宽作为fixed_crop()函数的输入。 36 else: 37 auglist.append(CenterCropAug(crop_size, inter_method)) 38 #可以看出不管你是否要做crop,只要你给定了data_shape参数,就默认要将输入图像做crop操作。 39 # 因此如果你不想在test的时候做crop,可以在这修改源码。 40 41 # 随机镜像处理,参数是0.5,HorizontalFlipAug()函数调用nd.flip()函数做水平翻转 42 if rand_mirror: 43 auglist.append(HorizontalFlipAug(0.5)) 44 45 46 # CastAug()函数主要是将数据格式转化为float32 47 auglist.append(CastAug()) 48 49 50 # 这三个参数分别是亮度,对比度,饱和度。当你对这三个参数设置了值, 51 # 就会调用ColorJitterAug()函数对其相应的亮度或对比度或饱和度做改变 52 if brightness or contrast or saturation: 53 auglist.append(ColorJitterAug(brightness, contrast, saturation)) 54 55 56 # 这个部分主要是添加pca噪声的,具体可以看LightingAug()函数 57 if pca_noise > 0: 58 eigval = np.array([55.46, 4.794, 1.148]) 59 eigvec = np.array([[-0.5675, 0.7192, 0.4009], 60 [-0.5808, -0.0045, -0.8140], 61 [-0.5836, -0.6948, 0.4203]]) 62 auglist.append(LightingAug(pca_noise, eigval, eigvec)) 63 64 65 # mean这个参数主要是和归一化相关。这里的assert语句前面已经介绍过了。mean参数默认是None,这种情况下是不会进入下面的if elif条件函数的。 66 # 如果想进行均值操作,可以设置mean为True,那么就会进入第一个if条件,如果你设置为其他值,就会进入elif条件, 67 # 这个时候如果你的mean不符合要求,比如isinstance函数用来判断类型,就会触发异常。 68 if mean is True: 69 mean = np.array([123.68, 116.28, 103.53]) 70 elif mean is not None: 71 assert isinstance(mean, np.ndarray) and mean.shape[0] in [1, 3] 72 73 # std与mean同理 74 if std is True: 75 std = np.array([58.395, 57.12, 57.375]) 76 elif std is not None: 77 assert isinstance(std, np.ndarray) and std.shape[0] in [1, 3] 78 79 # 这里需要mean和std同时都设置正确才能进行预处理,如果你只设置了mean,没有设置std,那么还是没有启动归一化的预处理。 80 # 这里主要调用ColorNormalizeAug()函数,这个函数调用color_normalize()函数,这个函数的实现很简单, 81 # 就是将原图像的像素值减去均值mean,然后除以标准差std得到返回值。 82 if mean is not None and std is not None: 83 auglist.append(ColorNormalizeAug(mean, std)) 84 85 # 最后返回预处理的列表 86 return auglist
Ref:
『MXNet』第八弹_数据处理API_上
第八弹_数据处理API_下_Image IO专题
im2rec脚本使用以及数据读取