tensorflow2.0 使用fit实现复杂自定义loss函数

复制代码
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import layers as KL
from tensorflow.python.keras import models as KM
import numpy as np

class ComplicatedLoss(KL.Layer):
    def __init__(self, **kwargs):
        super(ComplicatedLoss, self).__init__(**kwargs)
    def call(self, inputs, **kwargs):
            # 父类KL.Layer的call方法明确要求inputs为一个tensor,或者包含多个tensor的列表/元组        这里为多个tensor组成的列表        """        # 解包入参
         y_true, y_weight, y_pred = inputs        # 复杂的损失函数
         bce_loss = K.binary_crossentropy(y_true, y_pred)
         wbce_loss = K.mean(bce_loss * y_weight)        # 把自定义的loss添加进层使其生效
         self.add_loss(wbce_loss, inputs=True)        # 将每个loss加入metric方便在KERAS的进度条上实时追踪
         self.add_metric(wbce_loss, aggregation="mean", name="wbce_loss")
         self.add_metric(bce_loss, aggregation="mean", name="bce_loss")
         return wbce_loss

def my_model():
# input layers
    input_img = KL.Input([32, 32, 3], name="img1")
    input_lbl = KL.Input([32, 32, 1], name="lbl")
    input_weight = KL.Input([32, 32, 1], name="weight")
    predict = KL.Conv2D(2, [1, 1], padding="same")(input_img)
    my_loss = ComplicatedLoss()([input_lbl, input_weight, predict])
    model = KM.Model(inputs=[input_img, input_lbl, input_weight], outputs=[predict, my_loss])
    model.compile(optimizer="adam")
    return model

def get_fake_dataset():
    def map_fn(img, lbl, weight):
        inputs = {"img1": img, "lbl": lbl, "weight": weight}
        # inputs = [img, lbl, weight]
        targets = {}
        return inputs, targets
    fake_imgs = np.ones([100, 32, 32, 3])
    fake_lbls = np.ones([100, 32, 32, 1])
    fake_weights = np.zeros([100, 32, 32, 1])
    fake_dataset = tf.data.Dataset.from_tensor_slices((fake_imgs, fake_lbls, fake_weights)).map(map_fn).batch(10)
    return fake_dataset
if __name__ == '__main__':
    model = my_model()
    my_dataset = get_fake_dataset()
    model.fit(my_dataset,epochs=2)
复制代码

 

posted @   山…隹  阅读(467)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 25岁的心里话
· 按钮权限的设计及实现
历史上的今天:
2019-05-25 docker 容器后台运行命令
2019-05-25 docker 基础命令1
点击右上角即可分享
微信分享提示