Keras模型多输入-多输出设计思路

1.多输入、多输出

模型某一层接收多输入数据,以实现共享该层参数的目的。如对title和desc做文本分类,两类可以共享一个embedding数据,进而获取某种关联特征,示例代码如下:

title = Input(shape=(30,),name="title")
desc = Input(shape=(200,),name="desc")
# title和desc 共享 mebedding layer
embedding = Embedding(3000, 128)

title_embedd = embedding(title)
desc_embedd = embedding(desc)
title_lstm = LSTM(128)(title_embedd)
desc_lstm = LSTM(128)(desc_embedd)

out_title = Dense(1,activation="sigmoid",name="out_title")(title_lstm)
out_desc = Dense(1,activation="sigmoid",name="out_desc")(desc_lstm)
model = Model(inputs=[title,desc],outputs=[out_title,out_desc])
keras.utils.plot_model(model, show_shapes=True)

打印model:

 

2.不同输出设置不同的类型loss和weights

# model compile
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
# 输入和输出有多个,喂数据时整理成list形式,对应好
model.fit([title_input, desc_input],[title_out, desc_out])
# 不同的输出设置不同的loss和权重
model.compile(loss={"out_title":"binary_crossentropy", "out_desc":"categorical_crossentropy"}, optimizer="adam", metrics=["accuracy"])
model.compile(loss={"out_title":"binary_crossentropy", "out_desc":"categorical_crossentropy"}, loss_weights={"out_title":0.3,"out_desc":0.8},optimizer="adam",metrics=["accuracy"])

注:根据输出的名称对应设置类型,Keras这种思路无处不在

 Keras API:https://keras.io/api/models/model_training_apis/

posted @ 2022-02-11 11:28  今夜无风  阅读(877)  评论(0编辑  收藏  举报