从零开始学习MXnet(二)之dataiter
MXNet 采用 C++ 作为后端负责运算,Python、R 等语言作为前端供用户调用,这样既保证了效率,又方便了使用者。要完整地用 MXNet 训练自己的数据集,需要了解几个方面的内容,今天我们先谈一谈数据迭代器(data iterator)。
MXNet 中的 data iterator 和 Python 中的迭代器很相似:每次调用它的 next() 方法,都会返回一个 data batch。所谓 data batch,就是一批神经网络的输入和对应的 label,一般是 (n, c, h, w) 格式的图片输入,以及 (n, h, w) 格式(如分割任务)或每个样本一个标量(如分类任务)的 label。下面直接用官网上的一个简单例子来说明。
import numpy as np


class SimpleIter(object):
    """Minimal MXNet-style data iterator that yields synthetic batches.

    Parameters
    ----------
    data_names, label_names : list of str
        Names of the data / label inputs (should match the symbol's inputs).
    data_shapes, label_shapes : list of tuple
        Shape of each input, paired positionally with the names.
    data_gen, label_gen : list of callable
        One generator per input; each is called with that input's shape and
        must return something ``mx.nd.array`` accepts (e.g. a numpy array).
    num_batches : int
        Number of batches produced before ``StopIteration`` (until ``reset``).
    """

    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        # Materialize as lists: on Python 3 ``zip`` returns a one-shot iterator
        # that would be exhausted by the first ``next()`` (or by any caller
        # reading ``provide_data``), leaving every later batch empty.
        self._provide_data = list(zip(data_names, data_shapes))
        self._provide_label = list(zip(label_names, label_shapes))
        self.num_batches = num_batches
        self.data_gen = data_gen
        self.label_gen = label_gen
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        """Rewind to the first batch so the iterator can be reused for another epoch."""
        self.cur_batch = 0

    def __next__(self):  # Python 3 iterator protocol; delegates to next()
        return self.next()

    @property
    def provide_data(self):
        """List of ``(name, shape)`` pairs describing the data inputs."""
        return self._provide_data

    @property
    def provide_label(self):
        """List of ``(name, shape)`` pairs describing the label inputs."""
        return self._provide_label

    def next(self):
        """Return the next batch; raise ``StopIteration`` once ``num_batches`` are produced."""
        if self.cur_batch >= self.num_batches:
            raise StopIteration
        # Local import: this snippet never imported ``mx`` at module level.
        import mxnet as mx
        self.cur_batch += 1
        data = [mx.nd.array(gen(shape))
                for (_, shape), gen in zip(self._provide_data, self.data_gen)]
        assert len(data) > 0, "Empty batch data."
        label = [mx.nd.array(gen(shape))
                 for (_, shape), gen in zip(self._provide_label, self.label_gen)]
        assert len(label) > 0, "Empty batch label."
        # pad=0: every batch is full (consumers such as predict() read .pad).
        return mx.io.DataBatch(data=data, label=label, pad=0)
上面的代码是最简单的一个 dataiter 了,没有对数据做预处理,甚至没有真正去读取数据,但基本的意思已经到了。一个 dataiter 必须实现上面的几个方法。

provide_data 和 provide_label 返回的都是由 (name, shape) 组成的 list:
- 对图片输入来说,shape 一般是 (batch_size, channel, height, width);
- 对分割任务的 label 来说,shape 是 (batch_size, height, width);
- name 要和网络 symbol 中对应输入的名字一致。

reset() 用来在每个 epoch 结束后把迭代器重置到起点。通常也会在这里打乱读取图片的顺序,这样随机采样训练效果会好一点,一般是通过 shuffle 你的 lst(上篇用来读取图片的 lst)中的顺序来实现的。

next() 的作用很显然,就是返回下一个 databatch。当一个 epoch 的数据已经全部读完时,要记得 raise StopIteration,训练循环正是据此判断本轮 epoch 结束。注意 StopIteration 表示"数据读完了",而不是用来报告错误的;读取出错时应该抛出相应的异常。

还需要注意的是,databatch 中返回的数据类型是 mx.nd.NDArray。
下面是我最近做segmentation的时候用的一个稍微复杂的dataiter,多了预处理和shuffle等步骤:
# pylint: skip-file
import random

import cv2
import mxnet as mx
import numpy as np
import os
from mxnet.io import DataIter, DataBatch


class FileIter(DataIter):  # custom iterators usually subclass DataIter
    """File-list driven iterator for semantic segmentation (based on the fcn-xs example).

    Every line of the list file has the form::

        index<TAB>image_data_path<TAB>image_label_path

    Images are read with OpenCV, optionally flipped / scale-jittered / cropped /
    colour-jittered according to the config ``p``, and returned in batches of
    ``p.batch_size``. Because batches use fixed-size buffers of
    ``p.output_size``, images must end up exactly that size; with
    ``random_crop`` disabled the (scaled) inputs presumably already match it.

    Parameters
    ----------
    root_dir : string
        the root dir the images/labels (and the list file) live in
    flist_name : string
        name of the list file, relative to ``root_dir``
    rgb_mean : tuple of 3 numbers
        per-channel mean in (R, G, B) order, subtracted from every image
    data_name : string
        the data name used in symbol data (default data name)
    label_name : string
        the label name used in symbol softmax_label (default label name)
    p : config object (required)
        user-defined config; attributes read: ``fac`` (label down-sampling
        factor), ``batch_size``, ``random_crop``, ``random_flip``,
        ``random_color``, ``random_scale`` ((low, high) scale range),
        ``output_size`` ((height, width) of the crop), ``color_aug_range``
        (max H, S, V offsets), ``use_rnn`` and ``num_hidden``.
    """

    def __init__(self, root_dir, flist_name, rgb_mean=(117, 117, 117),
                 data_name="data", label_name="softmax_label", p=None):
        super(FileIter, self).__init__()
        if p is None:
            raise ValueError("FileIter requires a config object `p`")

        self.fac = p.fac  # p is a user-defined config object
        self.root_dir = root_dir
        self.flist_name = os.path.join(self.root_dir, flist_name)
        self.mean = np.array(rgb_mean)  # (R, G, B)
        self.data_name = data_name
        self.label_name = label_name
        self.batch_size = p.batch_size
        self.random_crop = p.random_crop
        self.random_flip = p.random_flip
        self.random_color = p.random_color
        self.random_scale = p.random_scale
        self.output_size = p.output_size
        self.color_aug_range = p.color_aug_range
        self.use_rnn = p.use_rnn
        self.num_hidden = p.num_hidden
        if self.use_rnn:
            self.init_h_name = 'init_h'
            self.init_h = mx.nd.zeros((self.batch_size, self.num_hidden))
        # Start one batch before index 0 (same as reset()) so the first
        # iter_next() lands on 0; the old "-1" skipped samples if batch_size > 1.
        self.cursor = -self.batch_size

        self.data = mx.nd.zeros(self._data_shape())
        self.label = mx.nd.zeros(self._label_shape())
        self.data_list = []
        self.label_list = []
        self.dict = {}  # in-memory cache: absolute path -> decoded image
        # Read the list file up front so the images can be loaded by index later.
        with open(self.flist_name) as flist:
            for line in flist.read().splitlines():
                if not line.strip():  # tolerate blank (e.g. trailing) lines
                    continue
                _, data_img_name, label_img_name = line.split("\t")
                self.data_list.append(data_img_name)
                self.label_list.append(label_img_name)
        self.num_data = len(self.data_list)
        self.order = list(range(self.num_data))
        self._shuffle()

    def _data_shape(self):
        """Shape of one data batch: (batch, 3, out_h, out_w)."""
        return (self.batch_size, 3, self.output_size[0], self.output_size[1])

    def _label_shape(self):
        """Shape of one label batch; spatial size floor-divided by ``fac``."""
        # ``//`` keeps the shape integral on Python 3 (``/`` produced floats).
        return (self.batch_size,
                int(self.output_size[0] // self.fac),
                int(self.output_size[1] // self.fac))

    def _shuffle(self):
        random.shuffle(self.order)

    def _load(self, rel_path, flags):
        """Read an image with cv2 (cached in ``self.dict``) and fail loudly if unreadable.

        The whole data set is kept in memory: it is small and the shared
        server's disk IO was frequently saturated by other jobs.
        """
        path = os.path.join(self.root_dir, rel_path)
        img = self.dict.get(path)
        if img is None:
            img = cv2.imread(path, flags)
            if img is None:  # cv2.imread signals failure by returning None
                raise IOError("cannot read image: %s" % path)
            self.dict[path] = img
        return img

    def _read_img(self, img_name, label_name):
        """Load one image/label pair and apply the configured random augmentation.

        Returns ``(img, label)``: ``img`` is (3, H, W), RGB, mean-subtracted;
        ``label`` is uint8, (H // fac, W // fac).
        """
        img = self._load(img_name, cv2.IMREAD_COLOR)         # (h, w, 3), BGR
        label = self._load(label_name, cv2.IMREAD_GRAYSCALE)  # (h, w)
        # The arrays above are cache entries: every step below must return a
        # new array rather than modifying img/label in place.

        if self.random_flip and random.randint(0, 1) == 1:
            img = cv2.flip(img, 1)
            label = cv2.flip(label, 1)

        # scale jittering
        scale = random.uniform(self.random_scale[0], self.random_scale[1])
        new_size = (int(img.shape[1] * scale), int(img.shape[0] * scale))  # (w, h)
        img = cv2.resize(img, new_size, interpolation=cv2.INTER_NEAREST)
        label = cv2.resize(label, new_size, interpolation=cv2.INTER_NEAREST)

        if self.random_crop:
            out_h, out_w = self.output_size[0], self.output_size[1]
            start_w = np.random.randint(0, img.shape[1] - out_w + 1)
            start_h = np.random.randint(0, img.shape[0] - out_h + 1)
            img = img[start_h:start_h + out_h, start_w:start_w + out_w, :]
            label = label[start_h:start_h + out_h, start_w:start_w + out_w]

        if self.random_color:
            hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
            # Shift H, S and V by independent uniform offsets, clipped to uint8 range.
            for ch, max_shift in enumerate(self.color_aug_range[:3]):
                shift = random.uniform(-max_shift, max_shift)
                hsv[..., ch] = np.clip(hsv[..., ch] + shift, 0, 255)
            img = cv2.cvtColor(hsv.astype('uint8'), cv2.COLOR_HSV2BGR)

        # cv2 decodes to BGR while self.mean is (R, G, B): reorder channels
        # first so each mean value is subtracted from its own channel.
        img = np.array(img, dtype=np.float32)[:, :, ::-1] - self.mean.reshape(1, 1, 3)
        img = img.transpose(2, 0, 1)  # (h, w, c) -> (c, h, w)

        # Nearest-neighbour keeps class ids intact (bilinear would blend ids at
        # object boundaries); explicit floor-divided dsize matches the label buffer.
        label_size = (int(label.shape[1] // self.fac), int(label.shape[0] // self.fac))
        label_zoomed = cv2.resize(label, label_size, interpolation=cv2.INTER_NEAREST)
        return img, label_zoomed.astype('uint8')

    @property
    def provide_data(self):
        """The name and shape of data provided by this iterator"""
        data_desc = [(self.data_name, self._data_shape())]
        if self.use_rnn:
            data_desc.append((self.init_h_name, (self.batch_size, self.num_hidden)))
        return data_desc

    @property
    def provide_label(self):
        """The name and shape of label provided by this iterator"""
        return [(self.label_name, self._label_shape())]

    def get_batch_size(self):
        return self.batch_size

    def reset(self):
        """Rewind to the start of the data and reshuffle the sample order."""
        self.cursor = -self.batch_size
        self._shuffle()

    def iter_next(self):
        """Advance the cursor by one batch; return False once the data is exhausted."""
        self.cursor += self.batch_size
        return self.cursor < self.num_data

    def _getpad(self):
        """Number of samples in the current batch that wrap around to the start."""
        if self.cursor + self.batch_size > self.num_data:
            return self.cursor + self.batch_size - self.num_data
        else:
            return 0

    def _getdata(self):
        """Load data from underlying arrays, internal use only"""
        assert self.cursor < self.num_data, "DataIter needs reset."
        data = np.zeros(self._data_shape())
        label = np.zeros(self._label_shape())
        for i in range(self.batch_size):
            # Past the end of the data, wrap to the start of self.order to fill
            # the batch; _getpad() reports how many entries are such padding.
            idx = self.order[(self.cursor + i) % self.num_data]
            data[i], label[i] = self._read_img(self.data_list[idx], self.label_list[idx])
        return mx.nd.array(data), mx.nd.array(label)

    def next(self):
        """return one dict which contains "data" and "label" """
        if self.iter_next():
            data, label = self._getdata()
            data = [data, self.init_h] if self.use_rnn else [data]
            label = [label]
            return DataBatch(data=data, label=label,
                             pad=self._getpad(), index=None,
                             provide_data=self.provide_data,
                             provide_label=self.provide_label)
        else:
            raise StopIteration
到这里基本上正常的训练我们就可以开始了,但是当你有了很多新的想法的时候,你又会遇到新的问题...比如:multi input/output怎么办?
其实也很简单,只需要修改几个地方:
1、provide_data 和 provide_label:注意到之前它们 return 的都是一个 list,所以直接在 list 里按和之前一样的 (name, shape) 格式添加新的输入/输出就行了(name 要和 symbol 中的变量名对应)。
2、next():如果你需要传入 data 和 depth 两个输入,只需要把 data=[data, depth] 传给 DataBatch 就行了(顺序要和 provide_data 中的顺序一致),label 也同理。
值得一提的是,在 MXNet 中实现 multi loss 时,写网络的 symbol 需要注意一点:假设你有 softmax_loss 和 regression_loss 两个输出,那么只要在最后 return mx.symbol.Group([softmax_loss, regression_loss]) 即可。
总之......That's all~~~~