Tensorflow七种初始化函数
分类:
TensorFlow
一、tf.constant_initializer(value)
作用:将变量初始化为给定的常量,初始化一切所提供的值。
二、tf.zeros_initializer()
作用:将变量设置为全0;也可以简写为tf.Zeros()
三、tf.ones_initializer()
作用:将变量设置为全1;可简写为tf.Ones()
四、tf.random_normal_initializer(mean,stddev)
作用:将变量初始化为满足正太分布的随机值,主要参数(正太分布的均值和标准差),用所给的均值和标准差初始化均匀分布。
五、tf.truncated_normal_initializer(mean,stddev,seed,dtype)
作用:将变量初始化为满足正太分布的随机值,但如果随机出来的值偏离平均值超过2个标准差,那么这个数将会被重新随机。
- mean:用于指定均值;
- stddev用于指定标准差;
- seed:用于指定随机数种子;
- dtype:用于指定随机数的数据类型。通常只需要设定一个标准差stddev这一个参数就可以。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 | @tf_export ( "initializers.truncated_normal" , "truncated_normal_initializer" ) class TruncatedNormal(Initializer): """Initializer that generates a truncated normal distribution. These values are similar to values from a `random_normal_initializer` except that values more than two standard deviations from the mean are discarded and re-drawn. This is the recommended initializer for neural network weights and filters. Args: mean: a python scalar or a scalar tensor. Mean of the random values to generate. 一个python标量或一个标量张量。要生成的随机值的均值 stddev: a python scalar or a scalar tensor. Standard deviation of the random values to generate.一个python标量或一个标量张量。要生成的随机值的标准偏差。 seed: A Python integer. Used to create random seeds. See `tf.set_random_seed` for behavior.一个Python整数。用于创建随机种子。查看 tf.set_random_seed 行为。 dtype: The data type. Only floating point types are supported.数据类型。只支持浮点类型。 """ def __init__( self , mean = 0.0 , stddev = 1.0 , seed = None , dtype = dtypes.float32): self .mean = mean self .stddev = stddev self .seed = seed self .dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) def __call__( self , shape, dtype = None , partition_info = None ): if dtype is None : dtype = self .dtype return random_ops.truncated_normal( shape, self .mean, self .stddev, dtype, seed = self .seed) def get_config( self ): return { "mean" : self .mean, "stddev" : self .stddev, "seed" : self .seed, "dtype" : self .dtype.name } |
举例:bert中初始化token_type_embeddings、embedding_table时,假设token_type_embeddings服从正态分布
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | def embedding_postprocessor(input_tensor, use_token_type = False , token_type_ids = None , token_type_vocab_size = 16 , token_type_embedding_name = "token_type_embeddings" , use_position_embeddings = True , position_embedding_name = "position_embeddings" , initializer_range = 0.02 , max_position_embeddings = 512 , dropout_prob = 0.1 ): ... if use_token_type: if token_type_ids is None : raise ValueError( "`token_type_ids` must be specified if" "`use_token_type` is True." ) token_type_table = tf.get_variable( name = token_type_embedding_name, shape = [token_type_vocab_size, width], initializer = create_initializer(initializer_range)) ... def create_initializer(initializer_range = 0.02 ): """Creates a `truncated_normal_initializer` with the given range.""" return tf.truncated_normal_initializer(stddev = initializer_range) |
六、tf.random_uniform_initializer(a,b,seed,dtype)
作用:从a到b均匀初始化,将变量初始化为满足均匀分布的随机值,主要参数(最大值,最小值)。
七、tf.uniform_unit_scaling_initializer(factor,seed,dtypr)
作用:将变量初始化为满足均匀分布但不影响输出数量级的随机值
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 使用C#创建一个MCP客户端
· ollama系列1:轻松3步本地部署deepseek,普通电脑可用
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 按钮权限的设计及实现
2018-12-24 Python中获取字典中最值对应的键