将图片集使用迭代器,分batch输入到卷积神经网络中

问题来源:写了一个神经网络,需要用的测试集是本地图片。

第一次尝试解决:将本地图片读取,乱序,存成npz形式的文件。在第二次使用时,load这个npz文件。但这个方法针对图片量比较大的情况没办法应对,图片大小超过电脑内存。

第二次尝试解决:尝试将文件分批存储成npz形式,一次读取数据进行训练,但是在keras平台下难以训练。

第三次尝试解决:采用迭代器分批读取数据,使用fit_generator分批训练网络。这是需要注意的问题是,如果连续读入的都是一类文件,就会导致模型偏移。因为前半部分都是数据1,后半部分都是数据2.

解决思路如下:定义一个迭代器,在迭代器中,依次循环从数据文件夹中读取数据文件,这样可以保证一个batch中的每种类别的数据的个数是相同的。

  1 import os
  2 import cv2
  3 import keras
  4 import numpy as np
  5 def Generator(path, batch_size, data_num):
  6     i = 1
  7     data = []
  8     label = []
  9     while True:
 10         i = 1
 11         while i <= data_num:
 12             for j in range(1,3):
 13                 f = os.path.join('dataset', path, '%d' %j, '%d.jpg' %i)
 14                 im = cv2.imread(f)
 15                 im = cv2.resize(im, (227, 227))
 16                 data.append(im)
 17                 label.append(j-1)
 18             if(len(label) == batch_size):
 19                 data = np.array(data, dtype='float32')
 20                 label = keras.utils.to_categorical(label, 2)
 21                 yield data, label
 22                 data = []
 23                 label = []
 24             i += 1
 25 
 26 #*******************AlexNet_begin**************************
 27 from keras.models import Sequential
 28 from keras.layers.convolutional import Conv2D, MaxPooling2D
 29 from keras.layers import Dense, Flatten, Dropout, BatchNormalization
 30 
 31 batch_size = 4
 32 num_classes = 2
 33 epochs = 8
 34 
 35 model = Sequential()
 36 #第一层
 37 model.add(Conv2D(96, (11, 11),      #卷积核
 38                  strides=4,         #步长
 39                  input_shape=(227, 227, 3),
 40                  padding='valid',   #无填充
 41                  activation='relu'))
 42 model.add(MaxPooling2D(pool_size=(3, 3), strides=2 ))
 43 model.add(BatchNormalization())
 44 #第二层
 45 model.add(Conv2D(
 46     kernel_size=(27, 27),
 47     filters= 256,
 48     strides= 1,
 49     padding='same',
 50     activation='relu'
 51 ))
 52 model.add(MaxPooling2D(pool_size=(3, 3), strides=2 ))
 53 #第三层
 54 model.add(Conv2D(
 55     kernel_size=(3, 3),
 56     filters= 384,
 57     strides= 1,
 58     padding= 'same',
 59     activation= 'relu'
 60 ))
 61 #第四层
 62 model.add(Conv2D(
 63     kernel_size= (3, 3),
 64     filters= 384,
 65     strides= 1,
 66     padding= 'same',
 67     activation= 'relu'
 68 ))
 69 #第五层
 70 model.add(Conv2D(
 71     kernel_size= (3, 3),
 72     filters= 256,
 73     strides= 1,
 74     padding= 'same',
 75     activation= 'relu'
 76 ))
 77 model.add(MaxPooling2D(pool_size=(3, 3), strides=2 ))
 78 #第六层
 79 model.add(Flatten())
 80 model.add(Dense(128, activation='relu'))
 81 model.add(Dropout(0.5))
 82 #第七层
 83 model.add(Dense(128, activation='relu'))
 84 model.add(Dropout(0.5))
 85 #第八层
 86 model.add(Dense(2, activation='softmax'))
 87 #打印模型
 88 model.summary()
 89 
 90 model.compile('sgd', loss='categorical_crossentropy', metrics=['accuracy'])
 91 #model.fit(x_train, y_train, epochs=5, validation_data=[x_test, y_test])
 92 print('fit......')
 93 model.fit_generator(Generator('train', 4, 100),
 94                         epochs=epochs,
 95                         steps_per_epoch=50,
 96                         workers=1)
 97 model.save('model.h5')
 98 # Evaluate the model with the metrics we defined earlier
 99 print('evaluate......')
100 loss1, accuracy1 = model.evaluate_generator(Generator('test', 4, 20), steps = 10)
101 print('test......')
102 loss2, accuracy2 = model.evaluate_generator(Generator('test2', 4, 12), steps = 6)
103 print('test1 loss: ', loss1)
104 print('test1 accuracy: ', accuracy1)
105 print('test2 loss: ', loss2)
106 print('test2 accuracy: ', accuracy2)
107 #*******************AlexNet_end****************************

数据文件结构如图所示:

 

参考资料链接:

1、https://github.com/keras-team/keras/issues/7729

2、https://blog.csdn.net/learning_tortosie/article/details/85243310

3、https://keras-cn.readthedocs.io/en/latest/models/model/

4、https://blog.csdn.net/shahuzi/article/details/81210557

5、https://blog.csdn.net/yideqianfenzhiyi/article/details/79197570

6、https://zhuanlan.zhihu.com/p/31558973

posted @ 2019-03-08 16:50  小橙7  阅读(1275)  评论(0编辑  收藏  举报