一切过往,皆为序章,一切未知,皆为终章。

tf2 自定義loss加載報錯

問題描述

ValueError: Unknown loss function: bes_loss

問題場景

  • 訓練
margin = 0.6
theta = lambda t : (K.sign(t) + 1.) / 2
def bes_loss(y_true, y_pred):
    return - (1 - theta(y_true - margin) * theta(y_pred - margin)
            - theta(1 - margin - y_true) * theta(1 - margin - y_pred)
         ) * (y_true * K.log(y_pred + 1e-8) + (1 - y_true) * K.log(1 - y_pred + 1e-8))
···
model.compile(tf.optimizers.Adam(), loss=bes_loss,metrics=['accuracy'])
  • 測試
model = load_model(config.model_path, custom_objects={'bes_loss':bes_loss})

這樣的加載方式就會出現報錯,如問題描述

問題解決

model = load_model(config.model_path, custom_objects={'bes_loss':bes_loss}, compile = False)
model.compile(tf.optimizers.Adam(), loss=bes_loss, metrics=['accuracy'])

通過compile=False忽略加載錯誤報錯,然後再通過model.compile()加載模型的配置

posted @ 2022-04-06 16:22  爱吃帮帮糖  阅读(70)  评论(0编辑  收藏  举报