使用 Keras 编写自己的训练代码

 

由于某个 GitHub 项目（ManTraNet）只开源了测试代码，训练代码需要自己编写。

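版本：独立版 Keras 2.x + TensorFlow 1.x（代码用到了 tf.ConfigProto、tf.Session 和 keras.backend.tensorflow_backend 等 TensorFlow 1.x 时代的接口）。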

 

# import src.modelCore as modelCore
from src.modelCore import create_model
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import keras
from keras.callbacks import ModelCheckpoint
# TF1 session configuration: let TensorFlow grow GPU memory on demand instead
# of grabbing it all up front, and cap this process at half of the GPU memory.
config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(
        allow_growth=True,
        per_process_gpu_memory_fraction=0.5,
    )
)
# Make Keras run on a session created with the configuration above.
keras.backend.tensorflow_backend.set_session(tf.Session(config=config))


# 加载模型
def load_pretrain_model_by_index(pretrain_index, weight_file=None, freeze_featex=False):
    """Build a ManTraNet model for a pretrain configuration index.

    Index 4 maps to IMC model index 2 with window sizes [7, 15, 31]; any other
    index is used directly as the IMC model index with window sizes
    [7, 15, 31, 63].

    Despite the name, no pretrained weights are loaded unless ``weight_file``
    is given; by default the returned model keeps whatever initial weights
    ``create_model`` gives it (i.e. training starts from scratch).

    :param pretrain_index: configuration index selecting IMC model / windows.
    :param weight_file: optional path to a weight file (e.g. a
        ``ManTraNet_Ptrain<idx>.h5`` file) to load into the model before
        returning it. ``None`` (default) skips loading.
    :param freeze_featex: forwarded unchanged to ``create_model``; presumably
        freezes the feature extractor — TODO confirm in src.modelCore.
    :return: the model built by ``create_model``.
    """
    if pretrain_index == 4:
        IMC_model_idx, window_size_list = 2, [7, 15, 31]
    else:
        IMC_model_idx, window_size_list = pretrain_index, [7, 15, 31, 63]
    single_gpu_model = create_model(IMC_model_idx, freeze_featex, window_size_list)
    if weight_file is not None:
        # Optional fine-tuning start point; omitted by default so the
        # original "train from scratch" behaviour is preserved.
        single_gpu_model.load_weights(weight_file)
    return single_gpu_model


def trainGenerator(batch_size, train_path, image_folder, mask_folder, aug_dict, image_color_mode="rgb",
                   mask_color_mode="grayscale", image_save_prefix="image", mask_save_prefix="mask",
                   flag_multi_class=False, num_class=2, save_to_dir=None, target_size=(256, 256), seed=1):
    '''
    Yield matching (image_batch, mask_batch) pairs indefinitely.

    Images are read from ``train_path/image_folder`` and masks from
    ``train_path/mask_folder``. Two ImageDataGenerator instances are built
    from the same ``aug_dict`` and driven with the same ``seed``, so the
    random transformation applied to an image is also applied to its mask.
    Set ``save_to_dir`` to a directory to dump the augmented samples for
    visual inspection.

    :param batch_size: number of samples per yielded batch.
    :param train_path: root directory that contains both sub-folders.
    :param image_folder: name of the sub-folder holding the input images.
    :param mask_folder: name of the sub-folder holding the masks.
    :param aug_dict: keyword arguments forwarded to ImageDataGenerator.
    :param image_color_mode: color_mode used when reading images ("rgb" by default).
    :param mask_color_mode: color_mode used when reading masks ("grayscale" by default).
    :param image_save_prefix: file-name prefix for saved images (only with save_to_dir).
    :param mask_save_prefix: file-name prefix for saved masks (only with save_to_dir).
    :param flag_multi_class: currently unused; kept for interface compatibility.
    :param num_class: currently unused; kept for interface compatibility.
    :param save_to_dir: optional directory to save augmented samples into.
    :param target_size: (height, width) every image and mask is resized to.
    :param seed: shared random seed keeping image/mask augmentation in sync.
    :return: an infinite generator of ``(img, mask)`` tuples of numpy arrays.
    '''
    image_datagen = ImageDataGenerator(**aug_dict)
    mask_datagen = ImageDataGenerator(**aug_dict)
    image_generator = image_datagen.flow_from_directory(
        train_path,                     # root folder containing image_folder and mask_folder
        classes=[image_folder],         # read only this sub-folder
        class_mode=None,                # yield pixel data only, no class labels
        color_mode=image_color_mode,
        # Always forward target_size: omitting it does NOT keep the native
        # image size, flow_from_directory then silently uses its own default.
        target_size=target_size,
        batch_size=batch_size,          # number of images per yielded batch
        save_to_dir=save_to_dir,        # optional dump of augmented samples
        save_prefix=image_save_prefix,  # only used when save_to_dir is set
        seed=seed)
    mask_generator = mask_datagen.flow_from_directory(
        train_path,
        classes=[mask_folder],
        class_mode=None,
        color_mode=mask_color_mode,
        target_size=target_size,        # must match the images so pairs align pixel-for-pixel
        batch_size=batch_size,
        save_to_dir=save_to_dir,
        save_prefix=mask_save_prefix,
        seed=seed)
    # Same seed => same shuffle order and same random transforms, so the i-th
    # image batch pairs with the i-th mask batch (assuming files in both
    # folders correspond one-to-one in sorted-name order).
    for img, mask in zip(image_generator, mask_generator):
        # NOTE(review): masks are yielded as raw pixel values (0-255 for 8-bit
        # files) while the model is compiled with binary_crossentropy, which
        # expects targets in [0, 1] — confirm the mask files are stored as 0/1
        # or rescale/threshold them here.
        yield img, mask


manTraNet = load_pretrain_model_by_index(4)
# Explicit keywords: learning rate 0.01, no momentum, time-based learning-rate
# decay 1e-6. Positional arguments map to different parameters across Keras
# versions (e.g. the third positional argument is `nesterov` in tf.keras).
sgd = SGD(lr=0.01, momentum=0.0, decay=1e-6)
manTraNet.compile(loss='binary_crossentropy', optimizer=sgd, metrics=['accuracy'])
train_path = r"C:\Users\DNY-006\Desktop\s2_data\s2_data\data"
# Not used below: trainGenerator reads masks from train_path/'train_mask'.
mask_path = r"C:\Users\DNY-006\Desktop\s2_data\s2_data\data\train_mask"
data_gen_args = dict(rotation_range=0.2,       # degrees of random rotation (0.2 is almost none - confirm intent)
                     width_shift_range=0.05,   # random horizontal shift, as a fraction of the width
                     height_shift_range=0.05,  # random vertical shift, as a fraction of the height
                     shear_range=0.05,         # shear angle, counter-clockwise
                     zoom_range=0.05,          # random zoom within [1 - 0.05, 1 + 0.05]
                     horizontal_flip=True,     # random left-right flips
                     fill_mode='nearest')      # fill pixels outside the input with the nearest value
# Batch size 1; images come from train_path/'1111', masks from train_path/'train_mask'.
train_generator = trainGenerator(1, train_path, '1111', 'train_mask', data_gen_args, save_to_dir=None)
# Save weights to ManTraNet_owndata.hdf5 whenever the training 'loss' improves.
model_checkpoint = ModelCheckpoint('ManTraNet_owndata.hdf5', monitor='loss', verbose=1, save_best_only=True)
# To validate during training, also pass validation_data=<generator>, validation_steps=<n>.
manTraNet.fit_generator(train_generator, steps_per_epoch=1000, epochs=60, callbacks=[model_checkpoint])
# import src.modelCore as modelCore
from src.modelCore import create_model
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import keras
from keras.callbacks import ModelCheckpoint
# TF1 session configuration: let TensorFlow grow GPU memory on demand instead
# of grabbing it all up front, and cap this process at half of the GPU memory.
config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(
        allow_growth=True,
        per_process_gpu_memory_fraction=0.5,
    )
)
# Make Keras run on a session created with the configuration above.
keras.backend.tensorflow_backend.set_session(tf.Session(config=config))


# 加载模型
def load_pretrain_model_by_index(pretrain_index, weight_file=None, freeze_featex=False):
    """Build a ManTraNet model for a pretrain configuration index.

    Index 4 maps to IMC model index 2 with window sizes [7, 15, 31]; any other
    index is used directly as the IMC model index with window sizes
    [7, 15, 31, 63].

    Despite the name, no pretrained weights are loaded unless ``weight_file``
    is given; by default the returned model keeps whatever initial weights
    ``create_model`` gives it (i.e. training starts from scratch).

    :param pretrain_index: configuration index selecting IMC model / windows.
    :param weight_file: optional path to a weight file (e.g. a
        ``ManTraNet_Ptrain<idx>.h5`` file) to load into the model before
        returning it. ``None`` (default) skips loading.
    :param freeze_featex: forwarded unchanged to ``create_model``; presumably
        freezes the feature extractor — TODO confirm in src.modelCore.
    :return: the model built by ``create_model``.
    """
    if pretrain_index == 4:
        IMC_model_idx, window_size_list = 2, [7, 15, 31]
    else:
        IMC_model_idx, window_size_list = pretrain_index, [7, 15, 31, 63]
    single_gpu_model = create_model(IMC_model_idx, freeze_featex, window_size_list)
    if weight_file is not None:
        # Optional fine-tuning start point; omitted by default so the
        # original "train from scratch" behaviour is preserved.
        single_gpu_model.load_weights(weight_file)
    return single_gpu_model


def trainGenerator(batch_size, train_path, image_folder, mask_folder, aug_dict, image_color_mode="rgb",
                   mask_color_mode="grayscale", image_save_prefix="image", mask_save_prefix="mask",
                   flag_multi_class=False, num_class=2, save_to_dir=None, target_size=(256, 256), seed=1):
    '''
    Yield matching (image_batch, mask_batch) pairs indefinitely.

    Images are read from ``train_path/image_folder`` and masks from
    ``train_path/mask_folder``. Two ImageDataGenerator instances are built
    from the same ``aug_dict`` and driven with the same ``seed``, so the
    random transformation applied to an image is also applied to its mask.
    Set ``save_to_dir`` to a directory to dump the augmented samples for
    visual inspection.

    :param batch_size: number of samples per yielded batch.
    :param train_path: root directory that contains both sub-folders.
    :param image_folder: name of the sub-folder holding the input images.
    :param mask_folder: name of the sub-folder holding the masks.
    :param aug_dict: keyword arguments forwarded to ImageDataGenerator.
    :param image_color_mode: color_mode used when reading images ("rgb" by default).
    :param mask_color_mode: color_mode used when reading masks ("grayscale" by default).
    :param image_save_prefix: file-name prefix for saved images (only with save_to_dir).
    :param mask_save_prefix: file-name prefix for saved masks (only with save_to_dir).
    :param flag_multi_class: currently unused; kept for interface compatibility.
    :param num_class: currently unused; kept for interface compatibility.
    :param save_to_dir: optional directory to save augmented samples into.
    :param target_size: (height, width) every image and mask is resized to.
    :param seed: shared random seed keeping image/mask augmentation in sync.
    :return: an infinite generator of ``(img, mask)`` tuples of numpy arrays.
    '''
    image_datagen = ImageDataGenerator(**aug_dict)
    mask_datagen = ImageDataGenerator(**aug_dict)
    image_generator = image_datagen.flow_from_directory(
        train_path,                     # root folder containing image_folder and mask_folder
        classes=[image_folder],         # read only this sub-folder
        class_mode=None,                # yield pixel data only, no class labels
        color_mode=image_color_mode,
        # Always forward target_size: omitting it does NOT keep the native
        # image size, flow_from_directory then silently uses its own default.
        target_size=target_size,
        batch_size=batch_size,          # number of images per yielded batch
        save_to_dir=save_to_dir,        # optional dump of augmented samples
        save_prefix=image_save_prefix,  # only used when save_to_dir is set
        seed=seed)
    mask_generator = mask_datagen.flow_from_directory(
        train_path,
        classes=[mask_folder],
        class_mode=None,
        color_mode=mask_color_mode,
        target_size=target_size,        # must match the images so pairs align pixel-for-pixel
        batch_size=batch_size,
        save_to_dir=save_to_dir,
        save_prefix=mask_save_prefix,
        seed=seed)
    # Same seed => same shuffle order and same random transforms, so the i-th
    # image batch pairs with the i-th mask batch (assuming files in both
    # folders correspond one-to-one in sorted-name order).
    for img, mask in zip(image_generator, mask_generator):
        # NOTE(review): masks are yielded as raw pixel values (0-255 for 8-bit
        # files) while the model is compiled with binary_crossentropy, which
        # expects targets in [0, 1] — confirm the mask files are stored as 0/1
        # or rescale/threshold them here.
        yield img, mask


manTraNet = load_pretrain_model_by_index(4)
# Explicit keywords: learning rate 0.01, no momentum, time-based learning-rate
# decay 1e-6. Positional arguments map to different parameters across Keras
# versions (e.g. the third positional argument is `nesterov` in tf.keras).
sgd = SGD(lr=0.01, momentum=0.0, decay=1e-6)
manTraNet.compile(loss='binary_crossentropy', optimizer=sgd, metrics=['accuracy'])
train_path = r"C:\Users\DNY-006\Desktop\s2_data\s2_data\data"
# Not used below: trainGenerator reads masks from train_path/'train_mask'.
mask_path = r"C:\Users\DNY-006\Desktop\s2_data\s2_data\data\train_mask"
data_gen_args = dict(rotation_range=0.2,       # degrees of random rotation (0.2 is almost none - confirm intent)
                     width_shift_range=0.05,   # random horizontal shift, as a fraction of the width
                     height_shift_range=0.05,  # random vertical shift, as a fraction of the height
                     shear_range=0.05,         # shear angle, counter-clockwise
                     zoom_range=0.05,          # random zoom within [1 - 0.05, 1 + 0.05]
                     horizontal_flip=True,     # random left-right flips
                     fill_mode='nearest')      # fill pixels outside the input with the nearest value
# Batch size 1; images come from train_path/'1111', masks from train_path/'train_mask'.
train_generator = trainGenerator(1, train_path, '1111', 'train_mask', data_gen_args, save_to_dir=None)
# Save weights to ManTraNet_owndata.hdf5 whenever the training 'loss' improves.
model_checkpoint = ModelCheckpoint('ManTraNet_owndata.hdf5', monitor='loss', verbose=1, save_best_only=True)
# To validate during training, also pass validation_data=<generator>, validation_steps=<n>.
manTraNet.fit_generator(train_generator, steps_per_epoch=1000, epochs=60, callbacks=[model_checkpoint])

 

 

参考

https://blog.csdn.net/Xnion/article/details/105797671

posted @ 2020-10-19 14:40  剑峰随心  阅读(329)  评论(0)  编辑  收藏  举报