Learning MXNet from Scratch (2): DataIter

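  MXNet is designed with a C++ backend that does the computation and front ends such as Python and R that users work with. This keeps it efficient while making it much more convenient to use. To train MXNet on your own dataset you need to understand several components; today let's talk about data iterators.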

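  A data iterator in MXNet is very similar to a Python iterator: each time its next method is called, it returns one data batch. A data batch holds the network's inputs and labels, typically images in (n, c, h, w) format together with labels in (n, h, w) format (for segmentation) or one scalar per sample (for classification). Let's go straight to a simple example from the official documentation.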
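The contract is easy to see without MXNet at all. Below is a minimal sketch that uses plain NumPy arrays in place of mx.nd.NDArray: a generator (the name iterate_batches and the toy shapes are made up for illustration) that hands out one batch per next() call, with images shaped (n, c, h, w) and segmentation labels shaped (n, h, w). For simplicity it drops a trailing partial batch; the segmentation iterator later in this post pads it instead.

```python
import numpy as np


def iterate_batches(images, labels, batch_size):
    """Yield (data, label) batches; images is (N, C, H, W), labels is (N, H, W).

    Hypothetical helper: a trailing batch with fewer than batch_size samples is dropped.
    """
    assert len(images) == len(labels), "every image needs a label"
    for start in range(0, len(images) - batch_size + 1, batch_size):
        yield images[start:start + batch_size], labels[start:start + batch_size]


images = np.zeros((10, 3, 32, 32), dtype=np.float32)  # 10 RGB images, 32x32
labels = np.zeros((10, 32, 32), dtype=np.uint8)        # one class-index map per image
batches = iterate_batches(images, labels, batch_size=4)
data, label = next(batches)
print(data.shape, label.shape)  # (4, 3, 32, 32) (4, 32, 32)
```

With 10 samples and batch_size=4 the generator yields two batches and then stops, which is the generator equivalent of raising StopIteration.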

  

import mxnet as mx
import numpy as np


class SimpleIter:
    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        # list(...) so the (name, shape) pairs can be reused on every batch;
        # in Python 3, zip() returns a one-shot iterator
        self._provide_data = list(zip(data_names, data_shapes))
        self._provide_label = list(zip(label_names, label_shapes))
        self.num_batches = num_batches
        self.data_gen = data_gen      # one function per input: shape -> array
        self.label_gen = label_gen    # one function per label: shape -> array
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        self.cur_batch = 0

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def next(self):
        if self.cur_batch < self.num_batches:
            self.cur_batch += 1
            data = [mx.nd.array(g(d[1])) for d, g in zip(self._provide_data, self.data_gen)]
            assert len(data) > 0, "Empty batch data."
            label = [mx.nd.array(g(d[1])) for d, g in zip(self._provide_label, self.label_gen)]
            assert len(label) > 0, "Empty batch label."
            return mx.io.DataBatch(data=data, label=label)
        else:
            raise StopIteration


# Usage (illustrative shapes): 4 samples per batch, 100 features each, class labels in [0, 10)
it = SimpleIter(['data'], [(4, 100)], [lambda s: np.random.uniform(-1, 1, s)],
                ['softmax_label'], [(4,)], [lambda s: np.random.randint(0, 10, s)])
for batch in it:
    print(batch.data[0].shape, batch.label[0].shape)

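  The code above is about the simplest possible dataiter: it does no preprocessing and does not even read real data (the arrays come from the data_gen/label_gen functions), but it shows the basic idea. A dataiter has to implement the methods above. provide_data returns a list of (name, shape) pairs, where for images the shape is (batch_size, channels, height, width); provide_label has the same form, e.g. (label_name, (batch_size, height, width)) for segmentation or (label_name, (batch_size,)) for classification. reset() is called between epochs. Here it only rewinds the batch counter, but it is also the natural place to shuffle the reading order, since random sampling usually trains better; in practice this is done by shuffling your .lst file list (the one used to read images in the previous post). next() returns one DataBatch, and once the epoch's data is used up it must raise StopIteration. This is not an error but the normal end-of-epoch signal that the training loop catches. Note that the data and label in a DataBatch are lists of mx.nd.NDArray, not NumPy arrays.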
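The SimpleIter above only rewinds a counter in reset(); shuffling once per epoch is something you add yourself. Here is a minimal stdlib-only sketch, assuming the iterator works on sample indices (the class name ShuffledIndexIter and its parameters are hypothetical). reset() reshuffles the order and rewinds the cursor, and __next__ raises StopIteration once fewer than a full batch remains, which is how the loop consuming it knows the epoch is over.

```python
import random


class ShuffledIndexIter:
    """Hypothetical index iterator: one shuffled pass over the samples per epoch."""

    def __init__(self, num_samples, batch_size, seed=None):
        self.order = list(range(num_samples))
        self.batch_size = batch_size
        self.rng = random.Random(seed)  # private RNG so epochs are reproducible with a seed
        self.reset()

    def reset(self):
        """Start a new epoch: reshuffle the sample order and rewind the cursor."""
        self.rng.shuffle(self.order)
        self.cursor = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Not enough samples left for a full batch: the epoch is over.
        if self.cursor + self.batch_size > len(self.order):
            raise StopIteration
        batch = self.order[self.cursor:self.cursor + self.batch_size]
        self.cursor += self.batch_size
        return batch


it = ShuffledIndexIter(num_samples=10, batch_size=3, seed=0)
for epoch in range(2):
    it.reset()
    for batch in it:
        print(epoch, batch)  # indices of the samples to load for this batch
```

In a real dataiter the returned indices would be used to look up image/label paths in the .lst list before reading and stacking them into a batch.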

  Below is a somewhat more complex dataiter I recently used for semantic segmentation. It adds preprocessing, shuffling and a few other steps:

  

# pylint: skip-file
import os
import random

import cv2
import mxnet as mx
import numpy as np
from mxnet.io import DataIter, DataBatch


class FileIter(DataIter):  # custom iterators usually subclass DataIter
    """FileIter adapted from the fcn-xs example. Builds a data iterator from a list file.

    Unlike the original fcn-xs example, which trains on whole images with
    batch_size fixed to 1, this version scales and crops every sample to
    p.output_size, so batch_size can come from the config.

    Parameters
    ----------
    root_dir : string
        the root directory that the images and labels live in
    flist_name : string
        the list file of images and labels; every line has the form:
        index \t image_data_path \t image_label_path
    rgb_mean : tuple
        per-channel mean (R, G, B) subtracted from every image
    data_name : string
        the data name used in the symbol (default "data")
    label_name : string
        the label name used in the symbol (default "softmax_label")
    p : object
        a user-defined config with batch_size, output_size (height, width),
        fac (label downsampling factor) and the augmentation settings
    """

    def __init__(self, root_dir, flist_name, rgb_mean=(117, 117, 117),
                 data_name="data", label_name="softmax_label", p=None):
        super(FileIter, self).__init__()
        assert p is not None, "a config object p is required"

        self.fac = p.fac  # p is my own config object
        self.root_dir = root_dir
        self.flist_name = os.path.join(self.root_dir, flist_name)
        self.mean = np.array(rgb_mean)  # (R, G, B)
        self.data_name = data_name
        self.label_name = label_name
        self.batch_size = p.batch_size
        self.random_crop = p.random_crop
        self.random_flip = p.random_flip
        self.random_color = p.random_color
        self.random_scale = p.random_scale
        self.output_size = p.output_size  # (height, width)
        self.color_aug_range = p.color_aug_range
        self.use_rnn = p.use_rnn
        self.num_hidden = p.num_hidden
        if self.use_rnn:
            self.init_h_name = 'init_h'
            self.init_h = mx.nd.zeros((self.batch_size, self.num_hidden))
        # labels are downsampled by fac; integer division keeps the shape integral
        self.label_size = (self.output_size[0] // self.fac, self.output_size[1] // self.fac)
        # start one batch before 0 so that the first iter_next() moves the cursor to 0
        self.cursor = -self.batch_size

        self.data_list = []
        self.label_list = []
        self.dict = {}  # in-memory image cache, see _read_img
        with open(self.flist_name) as f:  # read the .lst file to prepare for loading images
            for line in f.read().splitlines():
                if not line.strip():
                    continue  # skip blank lines, e.g. at the end of the file
                _, data_img_name, label_img_name = line.split("\t")
                self.data_list.append(data_img_name)
                self.label_list.append(label_img_name)
        self.num_data = len(self.data_list)
        self.order = list(range(self.num_data))
        self._shuffle()

    def _shuffle(self):
        random.shuffle(self.order)

    def _read_img(self, img_name, label_name):
        # I ran this on a shared server: the dataset is small and colleagues' jobs often
        # saturated the disk I/O, so every decoded image is kept in memory after the first read.
        img_path = os.path.join(self.root_dir, img_name)
        label_path = os.path.join(self.root_dir, label_name)
        if img_path not in self.dict:
            self.dict[img_path] = cv2.imread(img_path)  # BGR, (h, w, 3)
        if label_path not in self.dict:
            self.dict[label_path] = cv2.imread(label_path, 0)  # grayscale class map, (h, w)
        img = self.dict[img_path]
        label = self.dict[label_path]

        # Preprocessing after loading. Every step returns a new array, so the cache is never
        # modified, and every geometric step is applied to the image and the label alike.
        if self.random_flip and random.randint(0, 1) == 1:
            img = cv2.flip(img, 1)  # horizontal flip
            label = cv2.flip(label, 1)
        # scale jittering; nearest-neighbour keeps label values valid class indices
        scale = random.uniform(self.random_scale[0], self.random_scale[1])
        new_width = int(img.shape[1] * scale)
        new_height = int(img.shape[0] * scale)
        img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
        label = cv2.resize(label, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
        if self.random_crop:
            # assumes the scaled image is at least output_size in both dimensions;
            # without random_crop, the scaled image must already be exactly output_size
            oh, ow = self.output_size
            start_h = np.random.randint(0, img.shape[0] - oh + 1)
            start_w = np.random.randint(0, img.shape[1] - ow + 1)
            img = img[start_h:start_h + oh, start_w:start_w + ow, :]
            label = label[start_h:start_h + oh, start_w:start_w + ow]
        if self.random_color:
            # jitter hue, saturation and value in HSV space, then convert back to BGR
            img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
            hue = random.uniform(-self.color_aug_range[0], self.color_aug_range[0])
            sat = random.uniform(-self.color_aug_range[1], self.color_aug_range[1])
            val = random.uniform(-self.color_aug_range[2], self.color_aug_range[2])
            img = np.array(img, dtype=np.float32)
            # 8-bit OpenCV hue lies in [0, 180) and is circular, so wrap it instead of clipping
            img[..., 0] = np.mod(img[..., 0] + hue, 180)
            img[..., 1] = np.clip(img[..., 1] + sat, 0, 255)
            img[..., 2] = np.clip(img[..., 2] + val, 0, 255)
            img = cv2.cvtColor(img.astype('uint8'), cv2.COLOR_HSV2BGR)

        img = np.array(img, dtype=np.float32)  # (h, w, c), BGR
        img = img[:, :, ::-1] - self.mean.reshape(1, 1, 3)  # BGR -> RGB, then subtract the (R, G, B) mean
        img = img.transpose(2, 0, 1)  # (c, h, w)

        # downsample the label by fac to the network's output resolution
        label_zoomed = cv2.resize(label, (self.label_size[1], self.label_size[0]),
                                  interpolation=cv2.INTER_NEAREST)
        return img, label_zoomed.astype('uint8')

    @property
    def provide_data(self):
        """The name and shape of data provided by this iterator"""
        data_shape = (self.batch_size, 3, self.output_size[0], self.output_size[1])
        if self.use_rnn:
            return [(self.data_name, data_shape),
                    (self.init_h_name, (self.batch_size, self.num_hidden))]
        return [(self.data_name, data_shape)]

    @property
    def provide_label(self):
        """The name and shape of label provided by this iterator"""
        return [(self.label_name, (self.batch_size,) + self.label_size)]

    def get_batch_size(self):
        return self.batch_size

    def reset(self):
        """Called between epochs: rewind the cursor and reshuffle the order."""
        self.cursor = -self.batch_size
        self._shuffle()

    def iter_next(self):
        self.cursor += self.batch_size
        return self.cursor < self.num_data

    def _getpad(self):
        """Number of samples in the current batch that were wrapped around from the start."""
        if self.cursor + self.batch_size > self.num_data:
            return self.cursor + self.batch_size - self.num_data
        else:
            return 0

    def _getdata(self):
        """Load one batch from disk or the cache, internal use only"""
        assert self.cursor < self.num_data, "DataIter needs reset."
        data = np.zeros((self.batch_size, 3, self.output_size[0], self.output_size[1]), dtype=np.float32)
        label = np.zeros((self.batch_size,) + self.label_size, dtype=np.float32)
        for i in range(self.batch_size):
            # If the epoch runs out mid-batch, fill the remaining slots from the start
            # of the shuffled order; _getpad() reports how many slots were filled this way.
            pos = self.cursor + i
            idx = self.order[pos if pos < self.num_data else pos - self.num_data]
            data[i], label[i] = self._read_img(self.data_list[idx], self.label_list[idx])
        return mx.nd.array(data), mx.nd.array(label)

    def next(self):
        """Return one DataBatch holding the data (plus init_h if use_rnn) and the label."""
        if self.iter_next():
            data, label = self._getdata()
            data = [data, self.init_h] if self.use_rnn else [data]
            label = [label]
            return DataBatch(data=data, label=label,
                             pad=self._getpad(), index=None,
                             provide_data=self.provide_data,
                             provide_label=self.provide_label)
        else:
            raise StopIteration
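The least obvious part of FileIter is how the last, incomplete batch of an epoch is handled. Instead of dropping it, _getdata fills the remaining slots by wrapping around to the start of the shuffled order, and _getpad reports how many slots are filler so that evaluation can ignore them. Here is the same index arithmetic isolated in plain Python, as a hypothetical helper batch_indices (not part of MXNet), assuming batch_size is no larger than the dataset:

```python
def batch_indices(order, cursor, batch_size):
    """Return (indices, pad) for the batch starting at position `cursor` in `order`.

    Slots past the end of `order` are filled from its beginning (wrap-around);
    `pad` counts how many of the returned indices are such filler.
    Assumes batch_size <= len(order).
    """
    num_data = len(order)
    indices = [order[(cursor + i) % num_data] for i in range(batch_size)]
    pad = max(0, cursor + batch_size - num_data)
    return indices, pad


order = [3, 0, 4, 1, 2]  # a shuffled order over 5 samples
print(batch_indices(order, 0, 2))  # ([3, 0], 0)
print(batch_indices(order, 4, 2))  # ([2, 3], 1)
```

The pad value is what DataBatch(pad=...) carries, so a consumer can tell that the last pad samples of that batch are duplicates.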

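For segmentation, _read_img relies on one invariant: every geometric augmentation (flip, crop) must be applied identically to the image and its label map, or pixels and labels drift apart. Here is a NumPy-only sketch of that invariant, with a hypothetical function joint_flip_crop standing in for the OpenCV calls (cv2.flip with flipCode 1 is a horizontal flip, which is equivalent to reversing the column axis):

```python
import numpy as np


def joint_flip_crop(img, label, out_h, out_w, rng):
    """Apply one random horizontal flip and one random crop to an (H, W, C) image
    and its (H, W) label map, so both stay pixel-aligned.

    Assumes the inputs are at least out_h x out_w, like the random-crop branch above.
    """
    if rng.integers(0, 2) == 1:  # flip with probability 1/2
        img = img[:, ::-1]
        label = label[:, ::-1]
    top = rng.integers(0, img.shape[0] - out_h + 1)
    left = rng.integers(0, img.shape[1] - out_w + 1)
    return (img[top:top + out_h, left:left + out_w],
            label[top:top + out_h, left:left + out_w])


rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(48, 64, 3), dtype=np.uint8)
label = img[..., 0].copy()  # toy label equal to channel 0, so alignment is easy to check
crop_img, crop_label = joint_flip_crop(img, label, 32, 32, rng)
assert crop_img.shape == (32, 32, 3) and crop_label.shape == (32, 32)
assert (crop_label == crop_img[..., 0]).all()  # image and label are still aligned
```

Photometric augmentations such as the HSV jitter, by contrast, touch only the image, which is why that branch in _read_img leaves the label alone.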
    At this point we can basically start ordinary training. But once you have more ideas, new problems come up... for example: how do you handle multiple inputs/outputs?

    It is actually simple; only a few places need to change:

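      1. provide_data and provide_label: both already return a list, so just append another entry in the same (name, shape) format.

      2. next(): if you need to pass two inputs, data and depth, pass data=[data, depth] to the DataBatch, in the same order as the entries in provide_data. Labels work the same way.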
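Here is a sketch of the two changes for the data-plus-depth case, under stated assumptions: the input name depth, its single channel, and both helper names are hypothetical, and NumPy arrays stand in for the mx.nd.NDArray objects that a real iterator would pass to DataBatch(data=..., label=...). What it checks is that the arrays passed as data appear in the same order and with the same shapes as the entries of provide_data. A second label for a second loss would be handled the same way in provide_label.

```python
import numpy as np


def provide_data_rgbd(batch_size, height, width):
    """provide_data for two inputs: an RGB image and a one-channel depth map."""
    return [("data", (batch_size, 3, height, width)),
            ("depth", (batch_size, 1, height, width))]


def make_rgbd_batch(rgb, depth, provide_data):
    """Return the list that would be passed as DataBatch(data=...), in provide_data order."""
    batch = [rgb, depth]
    for (name, shape), arr in zip(provide_data, batch):
        assert arr.shape == shape, "%s: expected %s, got %s" % (name, shape, arr.shape)
    return batch


pd = provide_data_rgbd(2, 8, 8)
batch = make_rgbd_batch(np.zeros((2, 3, 8, 8)), np.zeros((2, 1, 8, 8)), pd)
print([name for name, _ in pd], [a.shape for a in batch])
```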

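    It is worth mentioning that multiple losses in MXNet need a little care when writing the network symbol. Suppose you have softmax_loss and regression_loss: just return mx.symbol.Group([softmax_loss, regression_loss]) at the end, and make sure the label names in provide_label match the label inputs of the two loss layers.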

    Anyway...... That's all~

  

 

posted @ 2017-02-05 15:13  亦轩Dhc