tensorflow的variable、variable_scope和get_variable的用法和区别
在tensorflow中,可以使用tf.Variable来创建一个变量,也可以使用tf.get_variable来创建一个变量,但是在一个模型需要使用其他模型的变量时,tf.get_variable就派上大用场了。
先分别介绍两个函数的用法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | import tensorflow as tf var1 = tf.Variable( 1.0 ,name = 'firstvar' ) print ( 'var1:' ,var1.name) var1 = tf.Variable( 2.0 ,name = 'firstvar' ) print ( 'var1:' ,var1.name) var2 = tf.Variable( 3.0 ) print ( 'var2:' ,var2.name) var2 = tf.Variable( 4.0 ) print ( 'var2:' ,var2.name) get_var1 = tf.get_variable(name = 'firstvar' ,shape = [ 1 ],dtype = tf.float32,initializer = tf.constant_initializer( 0.3 )) print ( 'get_var1:' ,get_var1.name) get_var1 = tf.get_variable(name = 'firstvar1' ,shape = [ 1 ],dtype = tf.float32,initializer = tf.constant_initializer( 0.4 )) print ( 'get_var1:' ,get_var1.name) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print ( 'var1=' ,var1. eval ()) print ( 'var2=' ,var2. eval ()) print ( 'get_var1=' ,get_var1. eval ()) |
结果如下:
我们来分析一下代码,tf.Varibale是以定义的变量名称为唯一标识的,如var1,var2,所以可以重复地创建name='firstvar'的变量,但是tensorflow会给它们按顺序取后缀,如firstvar_1:0,firstval_2:0,...,如果没有制定名字,系统会自动加上一个名字Variable:0。而且由于tf.Varibale是以定义的变量名称为唯一标识的,所以当第二次命名同一个变量名时,第一个变量就会被覆盖,所以var1由1.0变成2.0。
对于tf.get_variable,它是以指定的name属性为唯一标识,而不是定义的变量名称,所以不能同时定义两个变量name是相同的,例如下面这种就会报错:
1 2 3 4 | 1 get_var1 = tf.get_variable(name = 'a' ,shape = [ 1 ],dtype = tf.float32,initializer = tf.constant_initializer( 0.3 )) 2 print ( 'get_var1:' ,get_var1.name) 3 get_var2 = tf.get_variable(name = 'a' ,shape = [ 1 ],dtype = tf.float32,initializer = tf.constant_initializer( 0.4 )) 4 print ( 'get_var1:' ,get_var1.name) |
这样就会报错了。如果我们想声明两次相同name的变量,这时variable_scope就派上用场了,可以使用variable_scope将它们分开:
1 2 3 4 5 6 7 | import tensorflow as tf with tf.variable_scope( 'test1' ): get_var1 = tf.get_variable(name = 'firstvar' ,shape = [ 2 ],dtype = tf.float32) with tf.variable_scope( 'test2' ): get_var2 = tf.get_variable(name = 'firstvar' ,shape = [ 2 ],dtype = tf.float32) print ( 'get_var1:' ,get_var1.name) print ( 'get_var2:' ,get_var2.name) |
这样就不会报错了,variable_scope相当于声明了作用域,这样在不同的作用域存在相同的变量就不会冲突了,结果如下:
当然,scope还支持嵌套:
1 2 3 4 5 6 7 | import tensorflow as tf with tf.variable_scope( 'test1' ,): get_var1 = tf.get_variable(name = 'firstvar' ,shape = [ 2 ],dtype = tf.float32) with tf.variable_scope( 'test2' ,): get_var2 = tf.get_variable(name = 'firstvar' ,shape = [ 2 ],dtype = tf.float32) print ( 'get_var1:' ,get_var1.name) print ( 'get_var2:' ,get_var2.name) |
输出结果为:
怎么样?可以对照上面的结果体会一下不同。那么如何通过get_variable来实现变量共享呢?这就要用到variable_scope里的一个属性:reuse,顾名思义嘛,当把reuse设置成True时就可以了,它表示使用已经定义过的变量,这是get_variable就不会再创建新的变量,而是去找与name相同的变量:
1 2 3 4 5 6 7 8 9 10 11 12 13 | import tensorflow as tf with tf.variable_scope( 'test1' ,): get_var1 = tf.get_variable(name = 'firstvar' ,shape = [ 2 ],dtype = tf.float32) with tf.variable_scope( 'test2' ,): get_var2 = tf.get_variable(name = 'firstvar' ,shape = [ 2 ],dtype = tf.float32) print ( 'get_var1:' ,get_var1.name) print ( 'get_var2:' ,get_var2.name) with tf.variable_scope( 'test1' ,reuse = True ): get_var3 = tf.get_variable(name = 'firstvar' ,shape = [ 2 ],dtype = tf.float32) with tf.variable_scope( 'test2' ,): get_var4 = tf.get_variable(name = 'firstvar' ,shape = [ 2 ],dtype = tf.float32) print ( 'get_var3:' ,get_var3.name) print ( 'get_var4:' ,get_var4.name) |
输出结果如下:
当然前面说过,reuse=True是使用前面已经创建过的变量,如果仅仅只有从第八行到最后的代码,也会报错的,如果还是想这么做,就需要把reuse属性设置成tf.AUTO_REUSE
1 2 3 4 5 6 7 | import tensorflow as tf with tf.variable_scope( 'test1' ,reuse = tf.AUTO_REUSE): get_var3 = tf.get_variable(name = 'firstvar' ,shape = [ 2 ],dtype = tf.float32) with tf.variable_scope( 'test2' ,): get_var4 = tf.get_variable(name = 'firstvar' ,shape = [ 2 ],dtype = tf.float32) print ( 'get_var3:' ,get_var3.name) print ( 'get_var4:' ,get_var4.name) |
此时就不会报错,tf.AUTO_REUSE可以实现第一次调用variable_scope时,传入的reuse值为False,再次调用时,传入reuse的值就会自动变为True。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 如何调用 DeepSeek 的自然语言处理 API 接口并集成到在线客服系统
· 【译】Visual Studio 中新的强大生产力特性
· 2025年我用 Compose 写了一个 Todo App