TensorFlow学习笔记4——变量共享
因为最近在研究生成对抗网络GAN,在读别人的代码时发现了 with tf.variable_scope(self.name_scope_conv, reuse = reuse): 这样一条语句,查阅官方文档时明白了这是TensorFlow的变量共享机制。
举个例子:当我们研究生成对抗网络GAN的时候,判别器的任务是,如果接收到的是生成器生成的图像,判别器就尝试优化自己的网络结构来使自己输出0,如果接收到的是来自真实数据的图像,那么就尝试优化自己的网络结构来使自己输出1。也就是说,生成图像和真实图像经过判别器的时候,要共享同一套变量,所以TensorFlow引入了变量共享机制。
变量共享主要涉及到两个函数: tf.get_variable(<name>, <shape>, <initializer>) 和 tf.variable_scope(<scope_name>) 。
1. tf.get_variable(<name>, <shape>, <initializer>)
例如,我们搭建一个卷积层:
def conv_relu(input, kernel_shape, bias_shape): # Create variable named "weights". weights = tf.get_variable("weights", kernel_shape, initializer=tf.random_normal_initializer()) # Create variable named "biases". biases = tf.get_variable("biases", bias_shape, initializer=tf.constant_initializer(0.0)) conv = tf.nn.conv2d(input, weights, strides=[1, 1, 1, 1], padding='SAME') return tf.nn.relu(conv + biases)
然后,我们调用两次:
input1 = tf.random_normal([1,10,10,32]) input2 = tf.random_normal([1,20,20,32]) x = conv_relu(input1, kernel_shape=[5, 5, 1, 32], bias_shape=[32]) x = conv_relu(x, kernel_shape=[5, 5, 32, 32], bias_shape = [32]) # This fails.
会发现报错信息。因为执行的命令不明确:第二次调用时是创建一套新的变量(weights,biases)还是再次使用已存在的那一套变量(第一次调用时生成的weights和biases)呢?
这时就需要用到第二个函数: tf.variable_scope(<scope_name>)
2. tf.variable_scope(<scope_name>)
请看例子:
def my_image_filter(input_images): with tf.variable_scope("conv1"): # Variables created here will be named "conv1/weights", "conv1/biases". relu1 = conv_relu(input_images, [5, 5, 1, 32], [32]) with tf.variable_scope("conv2"): # Variables created here will be named "conv2/weights", "conv2/biases". return conv_relu(relu1, [5, 5, 32, 32], [32])
在不同的域内会生成不同的变量。
如果想要变量共享,TensorFlow提供了两种方法:
1. 设置 reuse=True
with tf.variable_scope("model"): output1 = my_image_filter(input1) with tf.variable_scope("model", reuse=True): output2 = my_image_filter(input2)
2. 调用 scope.reuse_variables()
with tf.variable_scope("model") as scope: output1 = my_image_filter(input1) scope.reuse_variables() output2 = my_image_filter(input2)
注:在官方文档的最后有这样一段话:Since depending on exact string names of scopes can feel dangerous, it's also possible to initialize a variable scope based on another one:
with tf.variable_scope("model") as scope: output1 = my_image_filter(input1) with tf.variable_scope(scope, reuse=True): output2 = my_image_filter(input2)