Keras模型多输入-多输出设计思路
1.多输入、多输出
模型某一层接收多输入数据,以实现共享该层参数的目的。如对title和desc做文本分类,两类可以共享一个embedding数据,进而获取某种关联特征,示例代码如下:
title = Input(shape=(30,),name="title") desc = Input(shape=(200,),name="desc") # title和desc 共享 mebedding layer embedding = Embedding(3000, 128) title_embedd = embedding(title) desc_embedd = embedding(desc) title_lstm = LSTM(128)(title_embedd) desc_lstm = LSTM(128)(desc_embedd) out_title = Dense(1,activation="sigmoid",name="out_title")(title_lstm) out_desc = Dense(1,activation="sigmoid",name="out_desc")(desc_lstm) model = Model(inputs=[title,desc],outputs=[out_title,out_desc]) keras.utils.plot_model(model, show_shapes=True)
打印model:
2.不同输出设置不同的类型loss和weights
# model compile model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"]) # 输入和输出有多个,喂数据时整理成list形式,对应好 model.fit([title_input, desc_input],[title_out, desc_out]) # 不同的输出设置不同的loss和权重 model.compile(loss={"out_title":"binary_crossentropy", "out_desc":"categorical_crossentropy"}, optimizer="adam", metrics=["accuracy"]) model.compile(loss={"out_title":"binary_crossentropy", "out_desc":"categorical_crossentropy"}, loss_weights={"out_title":0.3,"out_desc":0.8},optimizer="adam",metrics=["accuracy"])
注:根据输出的名称对应设置类型,Keras这种思路无处不在
Keras API:https://keras.io/api/models/model_training_apis/
时刻记着自己要成为什么样的人!