Keras study notes (2): data_generator

data_generator

Outputs one batch at a time; built on keras.utils.Sequence.

Base object for fitting to a sequence of data, such as a dataset.

Every Sequence must implement the __getitem__ and the __len__ methods. If you want to modify your dataset between epochs you may implement on_epoch_end. The method __getitem__ should return a complete batch.
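The contract above boils down to three methods. Below is a minimal sketch that mirrors it. To stay runnable without Keras installed, it deliberately does not subclass keras.utils.Sequence. The class name `ToySequence` and its `shuffle`/`seed` parameters are hypothetical illustrations, not part of the Keras API.

```python
import math
import random


class ToySequence:
    """Keras-free stand-in that follows the Sequence contract (__len__, __getitem__, on_epoch_end)."""

    def __init__(self, samples, labels, batch_size, shuffle=True, seed=0):
        self.samples = list(samples)
        self.labels = list(labels)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)
        self.order = list(range(len(self.samples)))  # sample order for the current epoch

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller than batch_size.
        return math.ceil(len(self.samples) / self.batch_size)

    def __getitem__(self, idx):
        # Return one complete batch: the idx-th slice of the current sample order.
        ids = self.order[idx * self.batch_size:(idx + 1) * self.batch_size]
        return [self.samples[i] for i in ids], [self.labels[i] for i in ids]

    def on_epoch_end(self):
        # Optional hook for changing the dataset between epochs: here, reshuffle the order.
        if self.shuffle:
            self.rng.shuffle(self.order)


seq = ToySequence(samples=list("abcdefg"), labels=list(range(7)), batch_size=3)
print(len(seq))  # 3
print(seq[2])    # (['g'], [6])
```

Because `on_epoch_end` shuffles indices instead of the data itself, each sample stays paired with its label.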

Notes

Sequences are a safer way to do multiprocessing. This structure guarantees that the network will only train once on each sample per epoch, which is not the case with generators.

Sequence example: https://keras.io/utils/#sequence
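Why does indexing make parallel loading safe? The toy simulation below uses threads from the standard library. It is not Keras's actual worker machinery. Each batch index is handed out exactly once, so every sample is seen once. A plain generator can only be asked for "the next" batch, so naively giving each worker its own copy replays the whole dataset once per worker. The helper names `get_batch` and `gen` are hypothetical.

```python
import math
from concurrent.futures import ThreadPoolExecutor

samples = list(range(10))
batch_size = 4
n_batches = math.ceil(len(samples) / batch_size)


def get_batch(idx):
    # Index-based access (what Sequence.__getitem__ provides): batch idx is always
    # the same slice, no matter which worker computes it.
    return samples[idx * batch_size:(idx + 1) * batch_size]


with ThreadPoolExecutor(max_workers=3) as pool:
    # Each batch index is submitted exactly once; map() yields results in index order.
    batches = list(pool.map(get_batch, range(n_batches)))
seen = [s for batch in batches for s in batch]
print(sorted(seen) == samples)  # True: every sample appears exactly once


def gen():
    # A plain generator has no index; consumers can only pull the next batch.
    for idx in range(n_batches):
        yield get_batch(idx)


# Giving each of 3 workers its own copy of the generator replays the data 3 times.
dup = [s for _ in range(3) for batch in gen() for s in batch]
print(len(dup))  # 30
```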

#!/usr/bin/env python
# coding: utf-8

from keras.utils import Sequence
import numpy as np
from keras.preprocessing import image
from skimage.io import imread


class My_Custom_Generator(Sequence):
    """Yields one batch of frame sequences per index.

    image_filenames: list of sequences, each a list of frame file paths
    labels: one label per sequence, aligned with image_filenames
    """

    def __init__(self, image_filenames, labels, batch_size):
        super().__init__()
        self.image_filenames = image_filenames
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; np.int was removed in NumPy 1.24, so cast with int()
        return int(np.ceil(len(self.image_filenames) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_y = self.labels[idx * self.batch_size : (idx + 1) * self.batch_size]
        batch_x = self.image_filenames[idx * self.batch_size : (idx + 1) * self.batch_size]
        batch_seq = []
        for x in batch_x:  # batch_x holds batch_size sequences (e.g. 16)
            seq_img = []
            for img in x:  # each sequence x holds its frame paths (e.g. 25)
                seq_img.append(image.img_to_array(imread(img)))
            batch_seq.append(seq_img)
        # shape: (batch, frames, height, width, channels)
        return np.array(batch_seq), np.array(batch_y)
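You can check the shapes the generator produces without Keras, skimage, or image files on disk. The sketch below mirrors the `__len__`/`__getitem__` logic of `My_Custom_Generator` but swaps `image.img_to_array(imread(path))` for a hypothetical `load_frame` stub that returns zeros. The class name `FrameSequenceBatches`, the file names, and the 8x8 frame size are assumptions made for illustration.

```python
import math
import numpy as np

FRAME_SHAPE = (8, 8, 3)  # assumed small frame size; the post's frames are 224x224x3


def load_frame(path):
    # Hypothetical stand-in for image.img_to_array(imread(path)); returns float32 zeros.
    return np.zeros(FRAME_SHAPE, dtype=np.float32)


class FrameSequenceBatches:
    """Keras-free mirror of My_Custom_Generator's batching logic."""

    def __init__(self, image_filenames, labels, batch_size):
        self.image_filenames = image_filenames
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.image_filenames) / self.batch_size)

    def __getitem__(self, idx):
        sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        batch_x = [[load_frame(p) for p in seq] for seq in self.image_filenames[sl]]
        return np.array(batch_x), np.array(self.labels[sl])


# 5 hypothetical sequences of 25 frame paths each, with one integer label per sequence.
filenames = [[f"video{i}/frame{j:02d}.png" for j in range(25)] for i in range(5)]
labels = [0, 1, 0, 1, 1]
frame_batches = FrameSequenceBatches(filenames, labels, batch_size=2)

x, y = frame_batches[0]
print(len(frame_batches), x.shape, y.shape)  # 3 (2, 25, 8, 8, 3) (2,)
x_last, _ = frame_batches[len(frame_batches) - 1]
print(x_last.shape)  # (1, 25, 8, 8, 3): the last batch is smaller
```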

Two ways to turn the batch data into a numpy.array

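Converting a list to numpy.array

Fast. Watch how the dimensions change when the nested list is converted: np.array turns each level of list nesting into a leading axis. For the result to be a regular array, every sequence must have the same number of frames and every frame the same shape.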

''' list
batch_x = X_train_filenames[idx * batch_size : (idx+1) * batch_size]
batch_seq = []
for x in batch_x:  # one sequence of frame paths per sample (e.g. 16 per batch)
    seq_img = []
    for img in x:  # e.g. 25 frame paths per sequence
        seq_img.append(image.img_to_array(imread(img)))
    batch_seq.append(seq_img)
batch_seq_list = np.array(batch_seq)  # e.g. (16, 25, 224, 224, 3)
'''
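The dimension changes are easy to see with tiny stand-in frames. The sketch below uses zero arrays in place of decoded images. It shows that list nesting depth becomes leading axes, that an extra pair of brackets adds a size-1 axis, and that `np.stack` gives the same result while stating the axes explicitly. Sequences of different lengths cannot be stacked into one array.

```python
import numpy as np

frame = np.zeros((4, 4, 3), dtype=np.float32)  # stand-in for one decoded frame
batch_seq = [[frame] * 5 for _ in range(2)]     # 2 sequences x 5 frames, as nested lists

arr = np.array(batch_seq)
print(arr.shape)  # (2, 5, 4, 4, 3): each list level becomes a leading axis

# Wrapping one sequence in an extra list adds a size-1 axis.
print(np.array([batch_seq[0]]).shape)  # (1, 5, 4, 4, 3)

# np.stack states the intended axes explicitly and gives the same array.
stacked = np.stack([np.stack(seq) for seq in batch_seq])
print(stacked.shape == arr.shape)  # True

# A sequence with a different number of frames cannot be stacked into one array.
ragged = [[frame] * 5, [frame] * 4]
try:
    np.stack([np.stack(seq) for seq in ragged])
    ragged_ok = True
except ValueError:
    ragged_ok = False
print(ragged_ok)  # False
```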
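
Using np.empty

Slow: np.append allocates a new array and copies all existing data on every call, so the work grows with every frame added. On the other hand, you only need to fix the per-sample shape (here (25, 224, 224, 3)) before starting, as the trailing axes of a zero-length np.empty array.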

'''numpy
batch_x = X_train_filenames[idx * batch_size : (idx+1) * batch_size]
batch_seq = np.empty((0, 25, 224, 224, 3), float)
for x in batch_x:  # one sequence of frame paths per sample
    seq_batch = np.empty((0, 224, 224, 3), float)
    for item in x:  # e.g. 25 frame paths per sequence
        seq_batch = np.append(seq_batch, np.expand_dims(image.img_to_array(imread(item)), axis=0), axis=0)
    # assign back to batch_seq (not a new name), otherwise only the last sequence is kept
    batch_seq = np.append(batch_seq, np.expand_dims(seq_batch, axis=0), axis=0)
'''
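If the full batch shape is known up front, a common alternative is to allocate the whole batch once with `np.empty` and fill it in place. This avoids the copy that each `np.append` call makes. The sketch below checks that both approaches produce the same array. It uses a hypothetical `load_frame` stub in place of `imread`/`img_to_array`, integer "paths", and small 8x8 frames, all as illustrative assumptions. It makes no claim about actual timings.

```python
import numpy as np

SEQ_LEN, H, W, C = 25, 8, 8, 3  # assumed sizes; the post uses 25 frames of 224x224x3


def load_frame(path):
    # Hypothetical stand-in for image.img_to_array(imread(path)).
    return np.full((H, W, C), float(path), dtype=np.float32)


# 4 sequences of fake "paths" (integers) standing in for frame file names.
batch_x = [[i * SEQ_LEN + j for j in range(SEQ_LEN)] for i in range(4)]

# Growing with np.append: every call allocates a new array and copies everything so far.
grown = np.empty((0, SEQ_LEN, H, W, C), dtype=np.float32)
for seq in batch_x:
    frames = np.empty((0, H, W, C), dtype=np.float32)
    for p in seq:
        frames = np.append(frames, load_frame(p)[np.newaxis], axis=0)
    grown = np.append(grown, frames[np.newaxis], axis=0)

# Preallocating the full batch once and filling it in place avoids the repeated copies.
batch = np.empty((len(batch_x), SEQ_LEN, H, W, C), dtype=np.float32)
for i, seq in enumerate(batch_x):
    for j, p in enumerate(seq):
        batch[i, j] = load_frame(p)

print(np.array_equal(grown, batch))  # True
```

Using `len(batch_x)` rather than a fixed batch size keeps the preallocated shape correct for a smaller final batch.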

posted @ 2019-07-19 20:52  Pent°