【516】keras 源码分析之 Dense
本文主要讲解一下 Dense 层的源码,Dense 层即最常用的全连接层,代码很简单,主要是重写了 build
与 call
方法,在我们自定义 Layer 时,也可以参考该层的实现。但是不需要这么复杂,只要写出必要的部分就可以了,参见下一篇博客。
1. Layer 类的相关说明
参考:TensorFlow函数:tf.layers.Layer —— W3Cschool TensorFlow 官方文档
参考:关于 Keras 网络层 —— keras 中文文档
基础层类。这是所有层都继承的类,实现了通用的基础结构功能。层是实现常见神经网络操作的类,例如卷积、批量规范等。这些操作需要管理变量、损失和更新,以及将 TensorFlow 操作应用于输入张量。用户只需实例化它,然后将其视为可调用的。
我们建议 Layer 的子代实现以下方法:
- __init__ ():在成员变量中保存配置
- build():当我们知道输入和 dtype 的形状时,从 __call__ 调用一次。应该有对 add_variable() 的调用,然后调用高级的 build() (设置为 self.built = True,这在用户想要在第一个 __call__ 之前手动调用 build() 时很好)。
- * call():确认 build() 已被调用一次后调用 __call__。实际上应该执行将层应用于输入张量的逻辑(应该作为第一个参数传入)。
2. Dense 源码解读
2.1 __init__ 函数重写
构造方法没什么好说的,就是一些简单的赋值。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | from keras.layers import Layer class Dense(Layer): def __init__( self , units, activation = None , use_bias = True , kernel_initializer = 'glorot_uniform' , bias_initializer = 'zeros' , kernel_regularizer = None , bias_regularizer = None , activity_regularizer = None , kernel_constraint = None , bias_constraint = None , * * kwargs): if 'input_shape' not in kwargs and 'input_dim' in kwargs: kwargs[ 'input_shape' ] = (kwargs.pop( 'input_dim' ),) super (Dense, self ).__init__( * * kwargs) self .units = units self .activation = activations.get(activation) self .use_bias = use_bias self .kernel_initializer = initializers.get(kernel_initializer) self .bias_initializer = initializers.get(bias_initializer) self .kernel_regularizer = regularizers.get(kernel_regularizer) self .bias_regularizer = regularizers.get(bias_regularizer) self .activity_regularizer = regularizers.get(activity_regularizer) self .kernel_constraint = constraints.get(kernel_constraint) self .bias_constraint = constraints.get(bias_constraint) self .input_spec = InputSpec(min_ndim = 2 ) self .supports_masking = True |
2.2 build 函数重写
build 方法中定义了两个 Variable 即权重,最后把 built 参数置为 True。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | def build( self , input_shape): assert len (input_shape) > = 2 # 维度取 input_shape 的最后一维 # 正好进行后面的叉乘 input_dim = input_shape[ - 1 ] # 设置权重矩阵,维度为 (input_dim, self.units),用于叉乘 self .kernel = self .add_weight(shape = (input_dim, self .units), initializer = self .kernel_initializer, name = 'kernel' , regularizer = self .kernel_regularizer, constraint = self .kernel_constraint) if self .use_bias: # 设置偏置 self .bias = self .add_weight(shape = ( self .units,), initializer = self .bias_initializer, name = 'bias' , regularizer = self .bias_regularizer, constraint = self .bias_constraint) else : self .bias = None self .input_spec = InputSpec(min_ndim = 2 , axes = { - 1 : input_dim}) self .built = True |
2.3 call 函数重写
call 方法把输入值与 build 方法中定义的权重进行了点积的操作,然后与 build 中的偏移量进行相加,最后经过激活函数返回最终的输出结果。
1 2 3 4 5 6 7 8 | def call( self , inputs): # 具体的 矩阵操作 output = K.dot(inputs, self .kernel) if self .use_bias: output = K.bias_add(output, self .bias, data_format = 'channels_last' ) if self .activation is not None : output = self .activation(output) return output |
2.4 compute_output_shape 函数重写
计算出输出tensor的维度并返回。
1 2 3 4 5 6 | def compute_output_shape( self , input_shape): assert input_shape and len (input_shape) > = 2 assert input_shape[ - 1 ] output_shape = list (input_shape) output_shape[ - 1 ] = self .units return tuple (output_shape) |
2.5 get_config 函数重写
保留一些中间值并以字典的形式返回。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | def get_config( self ): config = { 'units' : self .units, 'activation' : activations.serialize( self .activation), 'use_bias' : self .use_bias, 'kernel_initializer' : initializers.serialize( self .kernel_initializer), 'bias_initializer' : initializers.serialize( self .bias_initializer), 'kernel_regularizer' : regularizers.serialize( self .kernel_regularizer), 'bias_regularizer' : regularizers.serialize( self .bias_regularizer), 'activity_regularizer' : regularizers.serialize( self .activity_regularizer), 'kernel_constraint' : constraints.serialize( self .kernel_constraint), 'bias_constraint' : constraints.serialize( self .bias_constraint) } base_config = super (Dense, self ).get_config() return dict ( list (base_config.items()) + list (config.items())) |
分类:
AI Related / NLP
, AI Related
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)
2013-01-01 【096】2012年总结(流水账式)