深度学习-卷积神经网络-keras生成器训练模型-51

1. 模型的定义

from keras.applications.inception_v3 import InceptionV3
from keras.models import Model
from keras.layers import Dense, GlobalAvgPool2D
from keras.optimizers import RMSprop

"""
在AlexNet及其之前的大抵上所有的基于神经网络的机器学习算法都要在卷积层之后添加上
全连接层来进行特征的向量化,此外出于神经网络黑盒子的考虑,有时设计几个全连接网络还可以
提升卷积神经网络的分类性能,一度成为神经网络使用的标配。
但是,我们同时也注意到,全连接层有一个非常致命的弱点就是参数量过大,
特别是与最后一个卷积层相连的全连接层。一方面增加了Training以及testing的计算量,
降低了速度;另外一方面参数量过大容易过拟合。虽然使用了类似dropout等手段去处理,
但是毕竟dropout是hyper-parameter, 不够优美也不好实践。

那么我们有没有办法将其替代呢?当然有,就是GAP(Global Average Pooling)。
我们要明确以下,全连接层将卷积层展开成向量之后不还是要针对每个feature map进行分类吗,
GAP的思路就是将上述两个过程合二为一,一起做了
"""


def Create_Invep(classes=2):
    """
    加载了预训练权重文件,然后在模型的顶部添加了全局平均池化层和两个全连接层,最后编译模型并保存模型文件。
    这个模型可以用于分类问题,其中classes参数指定了分类的数量。在主函数中调用Create_Invep函数创建模型并打印模型的结构信息。
    :param classes:
    :return:
    """
    base_model = InceptionV3(weights='imagenet', include_top=False)
    # base_model = InceptionV3(weights=r'D:\05-learning\15-learning-ai\02_cancer_reg\InceptionV3.h5', include_top=False)
    x = base_model.output
    x = GlobalAvgPool2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(classes, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=predictions)
    for layer in base_model.layers:
        layer.trainable = True
    # decay每次更新后 学习率的衰减值
    model.compile(optimizer=RMSprop(lr=0.001, decay=0.9, epsilon=0.1),
                  loss='categorical_crossentropy', metrics=['accuracy'])
    return model


if __name__ == '__main__':
    model = Create_Invep()
    model.summary()

2. 图片batch generator

from keras.preprocessing.image import ImageDataGenerator

batch_size = 5
width, height = 299, 299

'''
rotation_range:整数,数据提升时图片随机转动的角度。随机选择图片的角度,是一个0~180的度数,取值为0~180。
rescale: 值将在执行其他处理前乘到整个图像上,我们的图像在RGB通道都是0~255的整数,这样的操作可能使图像的值过高或过低,所以我们将这个值定为0~1之间的数。
vertical_flip:布尔值,进行随机竖直翻转。
horizontal_flip:布尔值,进行随机水平翻转。随机的对图片进行水平翻转,这个参数适用于水平翻转不影响图片语义的时候。
preprocessing_function: 将被应用于每个输入的函数。该函数将在任何其他修改之前运行。该函数接受一个参数,为一张图片(秩为3的numpy array),并且输出一个具有相同shape的numpy array
width_shift_range:浮点数,图片宽度的某个比例,数据提升时图片随机水平偏移的幅度。
height_shift_range:浮点数,图片高度的某个比例,数据提升时图片随机竖直偏移的幅度。 
height_shift_range和width_shift_range是用来指定水平和竖直方向随机移动的程度,这是两个0~1之间的比例。
'''


def train_data(train_data_dir='data/train'):
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=15,
        shear_range=0.5,
        zoom_range=0.2,
        width_shift_range=0.3,
        height_shift_range=0.3,
        horizontal_flip=True,
        vertical_flip=True
    )
    # flow_from_directory(directory) 生成一个图像batch的生成器,
    # class_mode: 值为"categorical", "binary".用于计算分类正确率或调用
    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=(width, height),
        batch_size=batch_size,
        class_mode='categorical'
    )

    return train_generator


def valid_data(valid_data_dir='data/validation'):
    valid_datagen = ImageDataGenerator(rescale=1. / 255)
    valid_generator = valid_datagen.flow_from_directory(
        valid_data_dir,
        target_size=(width, height),
        batch_size=batch_size,
        class_mode='categorical'
    )
    return valid_generator

3. 模型训练

# python3中已经支持精确算法
from __future__ import division

from InceptionV3 import Create_Invep
from Generator import train_data
from Generator import valid_data

from keras.callbacks import TensorBoard


def train(steps_per_epoch=326, validation_steps=67, epochs=50):
    train_generator = train_data()
    valid_generator = valid_data()

    visualization = TensorBoard(log_dir='./logs', write_graph=True)

    model = Create_Invep()
    """
    直接将数据集全部装进显卡中,如果处理大型的数据集(例如图片的尺寸很大)或者是网络很深很宽,可能会造成显存不足
    这个情况会常遇到,解决的方法就是分块装入
    keras 默认情况下使用fit() 就是全部装入
    fit_generator方法就会已自己手写的方式用yield 逐块 装入
    """
    model.fit(train_generator,
              steps_per_epochs=steps_per_epoch,
              epochs=epochs,
              validation_data=valid_generator,
              validation_steps=validation_steps,
              verbose=1,
              callbacks=[visualization])

    model.save('InceptionV3.h5')


if __name__ == '__main__':
    train()

posted @ 2024-02-29 15:53  jack-chen666  阅读(9)  评论(0编辑  收藏  举报