TensorFlow--Creating a Model with a Custom Loss
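1. Define a simple functional model and a custom loss function. The loss is computed inside the graph: the labels enter as a second Input, and the output of the training model is the loss value itself.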
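from tensorflow.keras import Model, Input, layers, optimizers, losses


x = Input((5,))   # features
y = Input((1,))   # labels enter the graph as a second input

dense1 = layers.Dense(20, 'sigmoid')(x)
dense2 = layers.Dense(1)(dense1)   # no activation: a raw score, not a probability
model1 = Model(x, dense2)          # plain prediction model
# Lambda variant: the layer receives one list, so unpack it inside the function
# loss = layers.Lambda(lambda t: losses.binary_crossentropy(t[0], t[1], from_logits=True))([y, dense2])
# Note: with the default from_logits=False, the raw output of dense2 is clipped to
# [eps, 1 - eps] and treated as a probability; prefer from_logits=True or a 'sigmoid' on dense2.
loss = losses.binary_crossentropy(y, model1.output)

#################### Any custom loss function can be plugged in the same way, e.g. a YOLO loss
# (yolo_loss, anchors, num_classes, model_body and y_true are defined elsewhere, not in this snippet)
# model_loss = Lambda(yolo_loss, output_shape=(1,), name='yolo_loss',
#                     arguments={'anchors': anchors, 'num_classes': num_classes, 'ignore_thresh': 0.5})(
#     [*model_body.output, *y_true])

model = Model([model1.input, y], loss)   # training model: its output is the loss value

model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),   # `lr` is the deprecated argument name
    # The model already outputs the loss, so pass it through unchanged; the string
    # 'binary_crossentropy' would apply BCE again, to the loss itself. The 'accuracy'
    # metric is dropped because it would be computed on the loss value.
    loss=lambda y_true, y_pred: y_pred,
)
# Train with model.fit([features, labels], dummy_targets): the targets are ignored.

model.summary()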
Output of model.summary():
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape          Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 5)]           0
__________________________________________________________________________________________________
dense (Dense)                   (None, 20)            120         input_1[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1)             21          dense[0][0]
__________________________________________________________________________________________________
tf_op_layer_clip_by_value_2/Min [(None, 1)]           0           dense_1[0][0]
__________________________________________________________________________________________________
tf_op_layer_clip_by_value_2 (Te [(None, 1)]           0           tf_op_layer_clip_by_value_2/Minim
__________________________________________________________________________________________________
tf_op_layer_sub_5 (TensorFlowOp [(None, 1)]           0           tf_op_layer_clip_by_value_2[0][0]
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 1)]           0
__________________________________________________________________________________________________
tf_op_layer_add_6 (TensorFlowOp [(None, 1)]           0           tf_op_layer_clip_by_value_2[0][0]
__________________________________________________________________________________________________
tf_op_layer_add_7 (TensorFlowOp [(None, 1)]           0           tf_op_layer_sub_5[0][0]
__________________________________________________________________________________________________
tf_op_layer_Log_4 (TensorFlowOp [(None, 1)]           0           tf_op_layer_add_6[0][0]
__________________________________________________________________________________________________
tf_op_layer_sub_4 (TensorFlowOp [(None, 1)]           0           input_2[0][0]
__________________________________________________________________________________________________
tf_op_layer_Log_5 (TensorFlowOp [(None, 1)]           0           tf_op_layer_add_7[0][0]
__________________________________________________________________________________________________
tf_op_layer_mul_4 (TensorFlowOp [(None, 1)]           0           input_2[0][0]
                                                                  tf_op_layer_Log_4[0][0]
__________________________________________________________________________________________________
tf_op_layer_mul_5 (TensorFlowOp [(None, 1)]           0           tf_op_layer_sub_4[0][0]
                                                                  tf_op_layer_Log_5[0][0]
__________________________________________________________________________________________________
tf_op_layer_add_8 (TensorFlowOp [(None, 1)]           0           tf_op_layer_mul_4[0][0]
                                                                  tf_op_layer_mul_5[0][0]
__________________________________________________________________________________________________
tf_op_layer_Neg_2 (TensorFlowOp [(None, 1)]           0           tf_op_layer_add_8[0][0]
__________________________________________________________________________________________________
tf_op_layer_Mean_2 (TensorFlowO [(None,)]             0           tf_op_layer_Neg_2[0][0]
==================================================================================================
Total params: 141
Trainable params: 141
Non-trainable params: 0
__________________________________________________________________________________________________
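The tf_op_layer_* entries after dense_1 are losses.binary_crossentropy with its default from_logits=False, broken into single ops:

- the prediction is clipped to [eps, 1 - eps] (clip_by_value);
- eps is added inside each log (add_6, add_7);
- the two weighted log terms are summed (mul_4, mul_5, add_8);
- the sum is negated (Neg_2) and averaged over the last axis (Mean_2).

Below is a minimal pure-Python sketch of that computation. It assumes eps = 1e-7, the Keras default returned by tf.keras.backend.epsilon(). The names bce_from_probs and bce_from_logits are hypothetical helpers, not TensorFlow functions. The second helper uses the numerically stable formula max(z, 0) - z*y + log(1 + exp(-|z|)), which the tf.nn.sigmoid_cross_entropy_with_logits documentation gives and which Keras uses for the from_logits=True case.

```python
import math

EPS = 1e-7  # assumed: Keras's default fuzz factor, tf.keras.backend.epsilon()

# Hypothetical helper names for illustration; these are not TensorFlow APIs.


def bce_from_probs(y_true, y_pred, eps=EPS):
    """Binary cross-entropy on probabilities (from_logits=False), op by op."""
    per_unit = []
    for t, p in zip(y_true, y_pred):
        p = max(min(p, 1.0 - eps), eps)              # clip_by_value: Minimum, then Maximum
        bce = t * math.log(p + eps)                  # add_6 -> Log_4 -> mul_4
        bce += (1.0 - t) * math.log(1.0 - p + eps)   # sub_5 -> add_7 -> Log_5, sub_4 -> mul_5, add_8
        per_unit.append(-bce)                        # Neg
    return sum(per_unit) / len(per_unit)             # Mean over the last axis


def bce_from_logits(y_true, logits):
    """Stable sigmoid cross-entropy on raw outputs (from_logits=True)."""
    per_unit = [max(z, 0.0) - z * t + math.log1p(math.exp(-abs(z)))
                for t, z in zip(y_true, logits)]
    return sum(per_unit) / len(per_unit)


# One sample with one output unit, as in the model above.
print(bce_from_probs([1.0], [3.0]))   # about 0: the raw value 3.0 is clipped to 1 - eps
print(bce_from_logits([1.0], [3.0]))  # about 0.0486, i.e. log(1 + e**-3)
```

Because dense_1 has no activation, its raw output can leave [0, 1]. Outside that range the clipped version is flat, since clip_by_value passes no gradient for clipped values. A raw output of 3.0 with label 1 scores essentially 0. A raw 3.0 with label 0 scores about 15.4 but still gives no gradient to learn from. The from-logits form stays smooth for any raw value.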