How the loss is computed and how backpropagation works in a Keras model with multiple outputs
Source: https://stackoverflow.com/questions/57149476/how-is-a-multiple-outputs-deep-learning-model-trained
Keras calculations are graph based and use only one optimizer.
The optimizer is also a part of the graph, and in its calculations it gets the gradients of the whole group of weights. (Not two groups of gradients, one for each output, but one group of gradients for the entire model).
Mathematically, it's not really complicated. You have a final loss function made of:
loss = (main_weight * main_loss) + (aux_weight * aux_loss)  # you choose the weights in model.compile
All of these terms are defined by you, plus a series of other possible weights (sample weights, class weights, regularizer terms, etc.).
Where:
- main_loss is a function_of(main_true_output_data, main_model_output)
- aux_loss is a function_of(aux_true_output_data, aux_model_output)
And the gradients are just ∂(loss)/∂(weight_i) for all weights.
Once the optimizer has the gradients, it performs its optimization step once.
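The idea of "one scalar loss, one set of gradients, one update" can be sketched without Keras. The block below is a plain-Python stand-in, not Keras code: the toy model, the dictionary keys `shared`, `main` and `aux`, the squared-error losses, the loss weights and the learning rate are all assumptions made for illustration.

```python
# Plain-Python stand-in for a two-output model (NOT Keras code).
# Hypothetical toy: one shared weight feeds a main head and an auxiliary head.

def forward(w, x):
    """Return (main_out, aux_out) for a scalar input x."""
    shared = w["shared"] * x        # shared trunk (think embedding + LSTM)
    main_out = w["main"] * shared   # main head
    aux_out = w["aux"] * shared     # auxiliary head
    return main_out, aux_out

def total_loss(w, x, main_y, aux_y, main_weight=1.0, aux_weight=0.2):
    """Single scalar: weighted sum of the two per-output losses."""
    main_out, aux_out = forward(w, x)
    main_loss = (main_y - main_out) ** 2
    aux_loss = (aux_y - aux_out) ** 2
    return main_weight * main_loss + aux_weight * aux_loss

def gradients(w, x, main_y, aux_y, eps=1e-6):
    """Central finite differences of the single loss w.r.t. every weight."""
    grads = {}
    for name in w:
        up = dict(w)
        up[name] += eps
        down = dict(w)
        down[name] -= eps
        diff = total_loss(up, x, main_y, aux_y) - total_loss(down, x, main_y, aux_y)
        grads[name] = diff / (2 * eps)
    return grads

weights = {"shared": 0.5, "main": 1.0, "aux": -1.0}
x, main_y, aux_y = 2.0, 3.0, 1.0

loss_before = total_loss(weights, x, main_y, aux_y)
grads = gradients(weights, x, main_y, aux_y)  # one group of gradients for the whole model

learning_rate = 0.01  # plain gradient descent, standing in for the optimizer
weights_after = {k: weights[k] - learning_rate * grads[k] for k in weights}  # one step
loss_after = total_loss(weights_after, x, main_y, aux_y)

print(loss_before, grads, loss_after)
```

Note that the step is taken on the combined loss only. Nothing in this sketch guarantees that each individual term improves on every step.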
Questions:
How are the auxiliary branch weights updated, given that the branch is not connected directly to the main output?
- You have two output datasets: one for main_output and another for aux_output. You must pass both to fit, as in model.fit(inputs, [main_y, aux_y], ...).
- You also have two loss functions, one for each output: main_loss takes main_y and main_out, and aux_loss takes aux_y and aux_out.
- The two losses are summed: loss = (main_weight * main_loss) + (aux_weight * aux_loss)
- The gradients are calculated for the function loss once, and this function connects to the entire model.
  - The aux term will affect lstm_1 and embedding_1 in backpropagation.
  - Consequently, in the next forward pass (after the weights are updated) it will end up influencing the main branch. (Whether that makes the main output better or worse depends only on whether the aux output is useful.)
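Written out with the chain rule, the points above look roughly as follows. The notation is an assumption introduced here, not taken from the original answer: \(\lambda_{\text{main}}\) and \(\lambda_{\text{aux}}\) stand for main_weight and aux_weight, \(w_{\text{shared}}\) is any weight in a layer that feeds both outputs (such as lstm_1 or embedding_1), \(w_{\text{main}}\) is any weight that lies only on the path to the main output, and \(w_{\text{aux}}\) is any weight that lies only on the path to the auxiliary output. It also assumes the auxiliary output is not fed back into the main branch.

```latex
\begin{aligned}
\text{loss} &= \lambda_{\text{main}}\, L_{\text{main}} + \lambda_{\text{aux}}\, L_{\text{aux}} \\[4pt]
\frac{\partial\, \text{loss}}{\partial w_{\text{shared}}}
  &= \lambda_{\text{main}}\, \frac{\partial L_{\text{main}}}{\partial w_{\text{shared}}}
   + \lambda_{\text{aux}}\, \frac{\partial L_{\text{aux}}}{\partial w_{\text{shared}}} \\[4pt]
\frac{\partial\, \text{loss}}{\partial w_{\text{main}}}
  &= \lambda_{\text{main}}\, \frac{\partial L_{\text{main}}}{\partial w_{\text{main}}}
  \qquad \text{since } \frac{\partial L_{\text{aux}}}{\partial w_{\text{main}}} = 0 \\[4pt]
\frac{\partial\, \text{loss}}{\partial w_{\text{aux}}}
  &= \lambda_{\text{aux}}\, \frac{\partial L_{\text{aux}}}{\partial w_{\text{aux}}}
  \qquad \text{since } \frac{\partial L_{\text{main}}}{\partial w_{\text{aux}}} = 0
\end{aligned}
```

The auxiliary branch is therefore trained through the \(L_{\text{aux}}\) term alone, while the shared layers receive both terms in the same update.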
Is the part of the network between the root of the auxiliary branch and the main output affected by the weighting of the loss? Or does the weighting influence only the part of the network that is connected to the auxiliary output?
The weights are plain mathematics. You will define them in compile:
model.compile(optimizer=one_optimizer,
#you choose each loss
loss={'main_output':main_loss, 'aux_output':aux_loss},
#you choose each weight
loss_weights={'main_output': main_weight, 'aux_output': aux_weight},
metrics = ...)
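And the loss function will use them in loss = (weight1 * loss1) + (weight2 * loss2).
The rest is the mathematical calculation of ∂(loss)/∂(weight_i) for each weight. Layers that feed both outputs receive both weighted terms in their gradient, and layers that feed only one output receive only that output's term.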
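To see what the loss weights do to each part of the network, here is a minimal plain-Python sketch (again a stand-in, not Keras code). The function `loss_gradients`, the toy model, the squared-error losses and all numeric values are hypothetical. It writes out the chain rule by hand for one shared weight, one main-only weight and one aux-only weight, then varies aux_weight.

```python
# Plain-Python stand-in (NOT Keras code); every name and number is hypothetical.

def loss_gradients(main_weight, aux_weight):
    """Gradients of loss = main_weight * main_loss + aux_weight * aux_loss
    for a toy model with squared-error losses:
        shared   = w_shared * x
        main_out = w_main * shared
        aux_out  = w_aux * shared
    """
    x, main_y, aux_y = 2.0, 3.0, 1.0
    w_shared, w_main, w_aux = 0.5, 1.0, -1.0

    shared = w_shared * x
    main_err = w_main * shared - main_y  # main_out - main_y
    aux_err = w_aux * shared - aux_y     # aux_out - aux_y

    # Chain rule, term by term.
    d_main_head = main_weight * 2 * main_err * shared  # main-only weight
    d_aux_head = aux_weight * 2 * aux_err * shared     # aux-only weight
    d_shared = (main_weight * 2 * main_err * w_main * x
                + aux_weight * 2 * aux_err * w_aux * x)  # shared weight gets both terms
    return {"shared": d_shared, "main": d_main_head, "aux": d_aux_head}

for aux_weight in (0.0, 0.2, 1.0):
    print(aux_weight, loss_gradients(main_weight=1.0, aux_weight=aux_weight))
```

In this toy, changing aux_weight leaves the gradient of the main-only weight untouched within a single step, while the gradients of the shared weight and the aux-only weight change with it. The main-only weights are still affected indirectly on later steps, because the shared layers they read from were updated using both terms.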