Loading [Contrib]/a11y/accessibility-menu.js

alex_bn_lee

导航

< 2025年3月 >
23 24 25 26 27 28 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30 31 1 2 3 4 5

统计

【516】keras 源码分析之 Dense

参考:keras源码分析之Layer

参考:keras源码分析之Dense


  本文主要讲解一下 Dense 层的源码,Dense 层即最常用的全连接层,代码很简单,主要是重写了 build 与 call 方法,在我们自定义 Layer 时,也可以参考该层的实现。但是不需要这么复杂,只要写出必要的部分就可以了,参见下一篇博客。

1. Layer 类的相关说明

参考:TensorFlow函数:tf.layers.Layer —— W3Cschool TensorFlow 官方文档

参考:关于 Keras 网络层 —— keras 中文文档

  基础层类。这是所有层都继承的类,实现了通用的基础结构功能。层是实现常见神经网络操作的类,例如卷积、批量规范等。这些操作需要管理变量、损失和更新,以及将 TensorFlow 操作应用于输入张量。用户只需实例化它,然后将其视为可调用的。

  我们建议 Layer 的子代实现以下方法:

  • __init__ ():在成员变量中保存配置
  • build():当我们知道输入和 dtype 的形状时,从 __call__ 调用一次。应该有对 add_variable() 的调用,然后调用高级的 build() (设置为 self.built = True,这在用户想要在第一个 __call__ 之前手动调用 build() 时很好)。
  • * call():确认 build() 已被调用一次后调用 __call__。实际上应该执行将层应用于输入张量的逻辑(应该作为第一个参数传入)。

2. Dense 源码解读

2.1 __init__ 函数重写

  构造方法没什么好说的,就是一些简单的赋值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from keras.layers import Layer
 
class Dense(Layer):
    def __init__(self, units,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(Dense, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

  

2.2 build 函数重写

  build 方法中定义了两个 Variable 即权重,最后把 built 参数置为 True。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def build(self, input_shape):
    assert len(input_shape) >= 2
    # 维度取 input_shape 的最后一维
    # 正好进行后面的叉乘
    input_dim = input_shape[-1]
 
    # 设置权重矩阵,维度为 (input_dim, self.units),用于叉乘  
    self.kernel = self.add_weight(shape=(input_dim, self.units),
                                  initializer=self.kernel_initializer,
                                  name='kernel',
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
    if self.use_bias:
        # 设置偏置
        self.bias = self.add_weight(shape=(self.units,),
                                    initializer=self.bias_initializer,
                                    name='bias',
                                    regularizer=self.bias_regularizer,
                                    constraint=self.bias_constraint)
    else:
        self.bias = None
    self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
    self.built = True

  

2.3 call 函数重写

  call 方法把输入值与 build 方法中定义的权重进行了点积的操作,然后与 build 中的偏移量进行相加,最后经过激活函数返回最终的输出结果。

1
2
3
4
5
6
7
8
def call(self, inputs):
    # 具体的 矩阵操作
    output = K.dot(inputs, self.kernel)
    if self.use_bias:
        output = K.bias_add(output, self.bias, data_format='channels_last')
    if self.activation is not None:
        output = self.activation(output)
    return output

  

2.4 compute_output_shape 函数重写

  计算出输出tensor的维度并返回。

1
2
3
4
5
6
def compute_output_shape(self, input_shape):
    assert input_shape and len(input_shape) >= 2
    assert input_shape[-1]
    output_shape = list(input_shape)
    output_shape[-1] = self.units
    return tuple(output_shape)

  

2.5 get_config 函数重写

  保留一些中间值并以字典的形式返回。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def get_config(self):
    config = {
        'units': self.units,
        'activation': activations.serialize(self.activation),
        'use_bias': self.use_bias,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint': constraints.serialize(self.kernel_constraint),
        'bias_constraint': constraints.serialize(self.bias_constraint)
    }
    base_config = super(Dense, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  

posted on   McDelfino  阅读(540)  评论(0编辑  收藏  举报

编辑推荐:
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)
历史上的今天:
2013-01-01 【096】2012年总结(流水账式)
点击右上角即可分享
微信分享提示