keras自定义多输入损失函数

1. keras loss函数常规用法

# 输入仅为目标真值、预测值时
model.compile(loss='mse', optimizer =..., metrics = ...)   

 

2. keras loss自定义损失

def charbonnier(I_x, I_y, I_t, U, V, e)
    def loss_fun(y_true, y_pred):
        #定义loss,这里不必非使用y_true,y_pred
        loss = K.sqrt(K.pow((U*I_x + V*I_y + I_t), 2) + e)
        return K.sum(loss)
    return loss_fun

# y_true, y_pred不必传递
model.compile(loss=charbonnier(I_x, I_y, I_t, U, V, e), optimizer =..., metrics = ...)   

 

3. keras loss自定义损失的读取

对含有自定义损失函数的在读取时,需在load_model指定对应的损失函数名、参数

def dice_loss(smooth):
    def dice(y_true, y_pred):
        # print("y_true_f",y_true.shape)
        # print("y_pred_f",y_pred.shape) 
        return 1-dice_coef(y_true, y_pred, smooth)
    return dice

 model_dice=dice_loss(smooth=1e-5)
 model.compile(optimizer = Nadam(lr = 2e-4), loss = model_dice, metrics = ['accuracy'])

# 注意custom_objects中的key为dice, value为dice_loss
model=load_model("vnet_s_extend_epoch110.hdf5",custom_objects={'dice':dice_loss(1e-5)})

参考内容

keras自定义Loss

Keras Custom loss function to pass arguments other than y_true and y_pred

https://www.lmlphp.com/user/151109/article/item/2732980/

Custom Keras Loss (which does NOT have the form f(y_true, y_pred))

keras训练和加载自定义的损失函数

 

posted @   猴子吃桃_Q  阅读(300)  评论(0编辑  收藏  举报
编辑推荐:
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?
点击右上角即可分享
微信分享提示