【638】keras 多输出模型【实战】
[Keras] [multiple inputs / outputs] ValueError: No data provided for "xx". Need data for each key...
1. model.compile
对于多输出模型而言,多出来一个字典的形式,通过 model.compile 里面包含的 loss、loss_weight,可以通过字典的形式设置,如下所示:
1 2 3 4 5 6 7 | model. compile (optimizer = 'rmsprop' , # 不同输出层对应的损失函数 loss = { 'outputs1' : 'binary_crossentropy' , 'outputs1' : 'binary_crossentropy' }, # 不同输出层对应的损失函数权重值 loss_weight = { 'outputs1' : 0.5 , 'outputs1' : 0.5 }) |
注意:字典的 key 值并不是随意设置的,需要前后一致,并且需要指定到具体的模型输出的名称以及数据生成器中的,否则是无法对应的。
2. 模型架构
因为是单一输入就不考虑输入的名称了,输出的名称需要对应,如下所示:
1 2 3 4 5 6 7 8 9 10 11 | # 输入 inputs = keras. Input (...) # 模型主体部分 ... # 输出 outputs1 = layers.Conv2D( 1 , 3 , activation = "sigmoid" , padding = "same" , name = "outputs1" )(x1) outputs2 = layers.Conv2D( 1 , 3 , activation = "sigmoid" , padding = "same" , name = "outputs2" )(x2) model = keras.Model(inputs, [outputs1, outputs2]) |
注意:outputs1 与 outputs2 里面的 name 值与上面对应
3. 数据生成器
数据生成器需要生成对应格式的数据,特别是通过 key 值来对应输出数据的 labels,如下所示:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | # 图像生成器,生成可以直接输入到模型中的 generator,返回值是 tuple class ImageGenerator(keras.utils.Sequence): """Helper to iterate over the data (as Numpy arrays).""" def __init__( self , batch_size, img_size, input_img_paths, target_img_paths_louding, target_img_paths_louti): ... def __len__( self ): return len ( self .input_img_paths) / / self .batch_size def __getitem__( self , idx): """Returns tuple (input, target) correspond to batch #idx.""" x = np.zeros(( self .batch_size,) + self .img_size + ( 3 ,), dtype = "float32" ) ... y1 = np.zeros(( self .batch_size,) + self .img_size + ( 1 ,), dtype = "float32" ) ... y2 = np.zeros(( self .batch_size,) + self .img_size + ( 1 ,), dtype = "float32" ) ... # 注意 key 值的对应 y = { 'outputs1' : y1, 'outputs2' : y2} return x, y |
总结:实际上多输出或者多输入与单输入单输出没有实质性的区别,就是在数据处理和衔接上面容易出现问题,只要将 key 值对应,无论是 fit 还是 fit_generator 都可以实现。
分类:
Python Study
, AI Related
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)