萌小帅 一路向前

要有朴素的生活和遥远的梦想,不管明天天寒地冻,路遥马亡...^-^...

TensorFlow拾遗(一) 打印网络结构与变量

  使用tensorflow搭建网络之后,如果可视化一下网络的结构与变量,会对网络结构有一个更直观的了解。

  另外,这种方式也可以获得网络输出节点名称,便于pb文件的生成。

  在许多源码中都会包含这一操作,只不过大多可能并没有打印出来

variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # type:list
 1 import tensorflow as tf
 2 import os
 3 def txt_save(data, output_file):
 4     file = open(output_file, 'a')
 5     for i in data:
 6         s = str(i) + '\n'
 7         file.write(s)
 8     file.close()
 9 
10 
11 def network_param(input_checkpoint, output_file=None):
12     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
13     with tf.Session() as sess:
14         saver.restore(sess, input_checkpoint)
15         variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
16         for i in variable:
17             print(i)     # 打印
18         txt_save(variables, output_file)  # 保存txt   二选一
19 
20 if __name__ == '__main__':
21     checkpoint_path = 'ckpt/model.ckpt'
22     output_file = 'network_param.txt'
23     if not os.path.exists(output_file):
24         network_param(checkpoint_path, output_file)

  获得的txt文件部分如下所示,详细的txt文件猛戳这里

<tf.Variable 'Layer1/conv2d/kernel:0' shape=(7, 7, 3, 64) dtype=float32_ref>
<tf.Variable 'Layer1/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer1/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d/kernel:0' shape=(1, 1, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d_2/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_2/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_2/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/conv2d/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/batch_normalization/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/batch_normalization/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d/kernel:0' shape=(1, 1, 256, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization_1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization_1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d_2/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
网络结构输出部分展示

  其实,在打印网络结构与变量获得的结果中,也可以获得输出节点的名称,如下所示:

1 Score/Head/conv2d_1/BiasAdd 
2 BBox/Head/conv2d_1/BiasAdd

 

 

  

posted on 2020-07-15 09:34  墨殇浅尘  阅读(3466)  评论(0编辑  收藏  举报

导航