Keras的一些功能函数

 

1、模型的信息提取

1 # 节点信息提取
2 config = model.get_config()  # 把model中的信息,solver.prototxt和train.prototxt信息提取出来
3 model = Model.from_config(config)  # 还回去
4 # or, for Sequential:
5 model = Sequential.from_config(config) # 重构一个新的Model模型,用去其他训练,fine-tuning比较好用

 

2、模型概况查询

# 1、模型概括打印
model.summary()

# 2、返回代表模型的JSON字符串,仅包含网络结构,不包含权值。可以从JSON字符串中重构原模型:
from models import model_from_json

json_string = model.to_json()
model = model_from_json(json_string)

# 3、model.to_yaml:与model.to_json类似,同样可以从产生的YAML字符串中重构模型
from models import model_from_yaml

yaml_string = model.to_yaml()
model = model_from_yaml(yaml_string)

# 4、权重获取
model.get_layer()      #依据层名或下标获得层对象
model.get_weights()    #返回模型权重张量的列表,类型为numpy array
model.set_weights()    #从numpy array里将权重载入给模型,要求数组具有与model.get_weights()相同的形状。

# 查看model中Layer的信息
model.layers 查看layer信息

 

3、模型保存与加载

model.save_weights(filepath)
# 将模型权重保存到指定路径,文件类型是HDF5(后缀是.h5)

model.load_weights(filepath, by_name=False)
# 从HDF5文件中加载权重到当前模型中, 默认情况下模型的结构将保持不变。
# 如果想将权重载入不同的模型(有些层相同)中,则设置by_name=True,只有名字匹配的层才会载入权重

 

4、在keras中设定GPU的大小

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.3
set_session(tf.Session(config=config))

 

5、训练与保存模型

filepath = 'model-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
# fit model
model.fit(x, y, epochs=20, verbose=2, callbacks=[checkpoint], validation_data=(x, y))

 

6、在keras中使用tensorboard

RUN = RUN + 1 if 'RUN' in locals() else 1   # locals() 函数会以字典类型返回当前位置的全部局部变量。

    LOG_DIR = model_save_path + '/training_logs/run{}'.format(RUN)
    LOG_FILE_PATH = LOG_DIR + '/checkpoint-{epoch:02d}-{val_loss:.4f}.hdf5'   # 模型Log文件以及.h5模型文件存放地址

    tensorboard = TensorBoard(log_dir=LOG_DIR, write_images=True)
    checkpoint = ModelCheckpoint(filepath=LOG_FILE_PATH, monitor='val_loss', verbose=1, save_best_only=True)
    early_stopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1)

    history = model.fit_generator(generator=gen.generate(True), steps_per_epoch=int(gen.train_batches / 4),
                                  validation_data=gen.generate(False), validation_steps=int(gen.val_batches / 4),
                                  epochs=EPOCHS, verbose=1, callbacks=[tensorboard, checkpoint, early_stopping])

 

posted @ 2019-04-17 09:46  ylxn  阅读(246)  评论(0编辑  收藏  举报