Keras的一些功能函数
1、模型的信息提取
1 # 节点信息提取 2 config = model.get_config() # 把model中的信息,solver.prototxt和train.prototxt信息提取出来 3 model = Model.from_config(config) # 还回去 4 # or, for Sequential: 5 model = Sequential.from_config(config) # 重构一个新的Model模型,用去其他训练,fine-tuning比较好用
2、模型概况查询
# 1、模型概括打印 model.summary() # 2、返回代表模型的JSON字符串,仅包含网络结构,不包含权值。可以从JSON字符串中重构原模型: from models import model_from_json json_string = model.to_json() model = model_from_json(json_string) # 3、model.to_yaml:与model.to_json类似,同样可以从产生的YAML字符串中重构模型 from models import model_from_yaml yaml_string = model.to_yaml() model = model_from_yaml(yaml_string) # 4、权重获取 model.get_layer() #依据层名或下标获得层对象 model.get_weights() #返回模型权重张量的列表,类型为numpy array model.set_weights() #从numpy array里将权重载入给模型,要求数组具有与model.get_weights()相同的形状。 # 查看model中Layer的信息 model.layers 查看layer信息
3、模型保存与加载
model.save_weights(filepath) # 将模型权重保存到指定路径,文件类型是HDF5(后缀是.h5) model.load_weights(filepath, by_name=False) # 从HDF5文件中加载权重到当前模型中, 默认情况下模型的结构将保持不变。 # 如果想将权重载入不同的模型(有些层相同)中,则设置by_name=True,只有名字匹配的层才会载入权重
4、在keras中设定GPU的大小
import tensorflow as tf from keras.backend.tensorflow_backend import set_session config = tf.ConfigProto() config.gpu_options.per_process_gpu_memory_fraction = 0.3 set_session(tf.Session(config=config))
5、训练与保存模型
filepath = 'model-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5' checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min') # fit model model.fit(x, y, epochs=20, verbose=2, callbacks=[checkpoint], validation_data=(x, y))
6、在keras中使用tensorboard
RUN = RUN + 1 if 'RUN' in locals() else 1 # locals() 函数会以字典类型返回当前位置的全部局部变量。 LOG_DIR = model_save_path + '/training_logs/run{}'.format(RUN) LOG_FILE_PATH = LOG_DIR + '/checkpoint-{epoch:02d}-{val_loss:.4f}.hdf5' # 模型Log文件以及.h5模型文件存放地址 tensorboard = TensorBoard(log_dir=LOG_DIR, write_images=True) checkpoint = ModelCheckpoint(filepath=LOG_FILE_PATH, monitor='val_loss', verbose=1, save_best_only=True) early_stopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1) history = model.fit_generator(generator=gen.generate(True), steps_per_epoch=int(gen.train_batches / 4), validation_data=gen.generate(False), validation_steps=int(gen.val_batches / 4), epochs=EPOCHS, verbose=1, callbacks=[tensorboard, checkpoint, early_stopping])
谢谢!