keras中激活函数自定义(以mish函数为列)
若使用keras框架直接编辑函数调用会导致编译错误。因此,有2种方法可以实现keras的调用,其一使用lamda函数调用,
其二使用继承Layer层调用(如下代码)。如果使用继承layer层调用,那你可以将你想要实现的方式,通过call函数编辑,而
call函数的参数一般为输入特征变量[batch,h,w,c],具体实现如下代码:
class Mish(Layer):
'''
Mish Activation Function.
.. math::
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
tanh=(1 - e^{-2x})/(1 + e^{-2x})
Shape:
- Input: Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
- Output: Same shape as the input.
Examples:
>>> X_input = Input(input_shape)
>>> X = Mish()(X_input)
'''
def __init__(self, **kwargs):
super(Mish, self).__init__(**kwargs)
self.supports_masking = True
def call(self, inputs):
return inputs * K.tanh(K.softplus(inputs))
def get_config(self):
config = super(Mish, self).get_config()
return config
def compute_output_shape(self, input_shape):
'''
compute_output_shape(self, input_shape):为了能让Keras内部shape的匹配检查通过,
这里需要重写compute_output_shape方法去覆盖父类中的同名方法,来保证输出shape是正确的。
父类Layer中的compute_output_shape方法直接返回的是input_shape这明显是不对的,
所以需要我们重写这个方法。所以这个方法也是4个要实现的基本方法之一。
'''
return input_shape
有了mish激活函数,该如何调呢?以下代码简单演示其调用方式:
cov1=conv2d(卷积参数)(input) # 将输入input进行卷积操作
Mish()(cov1) # 将卷积结果输入定义的激活类中,实现mish激活