TensorFlow 变量作用域 变量管理 共享变量

当我们的神经网络拥有很复杂的模块时,我们使用TensorFlow提供的变量作用域(tf.variable_scope)来管理这些变量。

变量作用域的两个核心方法:

tf.get_variable(<name>, <shape>, <initializer>): 通过所给的名字创建或是返回一个变量.
tf.variable_scope(<scope_name>, <reuse>): 通过 tf.get_variable()为变量名指定命名空间.

上一篇文章中,我们已经有用到这两个方法,这一篇我们聚焦在这两方法的具体说明上。

tf.get_variable方法在创建初始化变量的时候与tf.Variable是完全一样的。

不过tf.get_variable可以通过tf.variable_scope生成的上下文管理器获取已经生成的变量值。

eg1:

import tensorflow as tf
with tf.variable_scope("foo"):
    v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))

with tf.variable_scope("foo"):
    v=tf.get_variable("v",[1])
    

执行这段代码,会报错。因为在命名空间foo中,name为v的变量已经存在。

在声明命名空间的时候,将reuse设置为True,这样tf.get_variable将直接获取已经声明的变量。

 

eg2:

import tensorflow as tf
with tf.variable_scope("foo"):
    v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))

with tf.variable_scope("foo", reuse=True):
    v1=tf.get_variable("v",[1])
    print(v==v1);

执行这段代码,输出为True

不过,reuse定义为True的时候,tf.get_variable只能获取已经拥有的变量。如果命名空间中没有定义这个变量就会报错。

比如讲上面代码的第二个命名空间名字改为bar再次执行就会报错。

 

eg3:

import tensorflow as tf
with tf.variable_scope("foo"):
    v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))

with tf.variable_scope("bar", reuse=True):
    v1=tf.get_variable("v",[1])
    print(v==v1);

执行这段代码会报错。

 

变量管理器还可以通过名称轻松访问变量。

eg4:

import tensorflow as tf
with tf.variable_scope("foo"):
    v = tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
    print(v.name)

with tf.variable_scope("foo"):
    with tf.variable_scope("bar"):
        v1 = tf.get_variable("v",[1])
        print(v1.name);

with tf.variable_scope("",reuse=True):
    v2 = tf.get_variable("foo/v");
    print(v2==v)
    print(v2==v1)
    v3 = tf.get_variable("foo/bar/v")
    print(v3==v)
    print(v3==v1)

输出结果:

foo/v:0
foo/bar/v:0
True
False
False
True

 

灵活的使用变量管理器我们可以在复杂的神经网络结构中大大的提高代码的可读性。

了解了tf.variable_scope 和 tf.get_variable的作用和用法以后,再回过头读读之前的代码,领悟一下他们在神经网络结构中发挥了怎样的作用。

 

使用实例

这里有一些指向怎么使用变量作用域的文件.特别是,他被大量用于 时间递归神经网络sequence-to-sequence模型,

FileWhat's in it?
models/image/cifar10.py 图像中检测对象的模型.
models/rnn/rnn_cell.py 时间递归神经网络的元方法集.
models/rnn/seq2seq.py 为创建sequence-to-sequence模型的方法集.

参考链接

http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/variable_scope.html

《Tensorflow+实战Google深度学习框架》5.3节

posted @ 2017-12-13 22:14  郭老猫  阅读(1810)  评论(0编辑  收藏  举报