keras 学习笔记(二) ——— data_generator
data_generator
每次输出一个 batch,基于 `keras.utils.Sequence` 实现。
Base object for fitting to a sequence of data, such as a dataset.
Every `Sequence` must implement the `__getitem__` and the `__len__` methods. If you want to modify your dataset between epochs you may implement `on_epoch_end`. The method `__getitem__` should return a complete batch.

Notes

`Sequence` are a safer way to do multiprocessing. This structure guarantees that the network will only train once on each sample per epoch which is not the case with generators.
Sequence example: https://keras.io/utils/#sequence
#!/usr/bin/env python
# coding: utf-8
from keras.utils import Sequence
import numpy as np
from keras.preprocessing import image
from skimage.io import imread
class My_Custom_Generator(Sequence):
    """Keras ``Sequence`` that yields one batch of image sequences per index.

    Each element of ``image_filenames`` is itself an iterable of image file
    paths (one sample = one sequence of frames). Every path is read with
    ``imread`` and converted with ``image.img_to_array``. ``labels`` is
    sliced in parallel with ``image_filenames``, so the two must be aligned
    index-for-index.

    NOTE(review): the author's notes elsewhere use sequences of 25 frames of
    224x224x3 images; those are example values only, confirm for your data.
    """

    def __init__(self, image_filenames, labels, batch_size):
        """Store the inputs.

        Args:
            image_filenames: indexable, sliceable collection whose elements
                are iterables of image file paths.
            labels: targets aligned with ``image_filenames``.
            batch_size: number of samples per batch.
        """
        # Harmless for the classic Sequence; required by Keras 3 (PyDataset)
        # to initialise its internal worker settings.
        super().__init__()
        self.image_filenames = image_filenames
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches per epoch (the last may be partial)."""
        # Integer ceiling division. The previous ``.astype(np.int)`` fails on
        # NumPy >= 1.24, where the ``np.int`` alias was removed.
        n_samples = len(self.image_filenames)
        return (n_samples + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Return ``(batch_x, batch_y)`` as NumPy arrays for batch ``idx``.

        ``batch_x`` is built by ``np.array`` over nested lists. It therefore
        has shape ``(batch, seq_len, *image_shape)`` only when every sample
        has the same number of frames and every image the same shape.
        """
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_x = self.image_filenames[start:stop]
        batch_y = self.labels[start:stop]
        # One list of frame arrays per sample; np.array stacks them.
        batch_seq = [
            [image.img_to_array(imread(path)) for path in sample]
            for sample in batch_x
        ]
        return np.array(batch_seq), np.array(batch_y)
两种将数据输出为numpy.array的方法
通过list转为numpy.array
速度快;list 转 array 时需要注意数据维度的变化:`np.array` 会把嵌套的 list 堆叠成 (batch, 序列长度, 图像维度...) 的数组,因此每个样本的帧数以及每张图像的尺寸都必须一致。
''' list
batch_x = X_train_filenames[idx * batch_size : (idx+1) * batch_size]
batch_seq = []  # one list of frame arrays per sample
for x in batch_x:      # x: one sample, an iterable of image paths
    seq_img = []
    for img in x:      # img: path of a single frame
        seq_img.append(image.img_to_array(imread(img)))
    batch_seq.append(seq_img)
# np.array stacks the nested lists into (batch, seq_len, *image_shape);
# all samples need the same frame count and all images the same shape.
batch_seq_list = np.array(batch_seq)
'''
利用np.empty
速度慢(每次 `np.append` 都会复制整个数组,总开销随数据量平方增长),但只需在开始前确定 batch 各维度的形状即可。
'''numpy
batch_x = X_train_filenames[idx * batch_size : (idx+1) * batch_size]
batch_seq = np.empty((0, 25, 224, 224, 3), float)
for x in batch_x:      # x: one sample, an iterable of 25 image paths
    seq_batch = np.empty((0, 224, 224, 3), float)
    for item in x:     # item: path of a single frame
        frame = image.img_to_array(imread(item))
        seq_batch = np.append(seq_batch, np.expand_dims(frame, axis=0), axis=0)
    # np.append returns a new array (it does not modify in place), so the
    # result must be assigned back to batch_seq for it to accumulate.
    # Each call copies the whole array, which is why this approach is slow.
    batch_seq = np.append(batch_seq, np.expand_dims(seq_batch, axis=0), axis=0)
'''