TensorFlow Notes (1): Printing the Network Structure and Variables
After building a network with TensorFlow, visualizing its structure and variables gives you a much more direct picture of the architecture.
In addition, this also reveals the names of the network's output nodes, which you need when generating a pb file.
Many open-source codebases already contain this operation; most of them simply never print the result. The core call is a single line:
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)  # returns a Python list of tf.Variable objects
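(For reference, in TF 1.x the shorthand below returns the same list; it is just an equivalent alternative, not what the code in this post uses.)

variables = tf.trainable_variables()  # equivalent shorthand for the GraphKeys.TRAINABLE_VARIABLES lookup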
Wrapped into a complete script that restores a checkpoint and dumps the trainable variables:

import tensorflow as tf
import os


def txt_save(data, output_file):
    # Append each item as one line of text
    file = open(output_file, 'a')
    for i in data:
        s = str(i) + '\n'
        file.write(s)
    file.close()


def network_param(input_checkpoint, output_file=None):
    # Rebuild the graph from the .meta file, then restore the weights
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        for i in variables:
            print(i)                          # print to the console
        txt_save(variables, output_file)      # or save to a txt file; pick one of the two


if __name__ == '__main__':
    checkpoint_path = 'ckpt/model.ckpt'
    output_file = 'network_param.txt'
    if not os.path.exists(output_file):
        network_param(checkpoint_path, output_file)
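As an aside, if you only need the variable names and shapes rather than the tf.Variable objects themselves, the checkpoint can be inspected without building a session at all. This is a minimal sketch of that alternative, assuming a TF 1.x checkpoint at the same ckpt/model.ckpt path; note it lists every saved variable, not only the trainable ones, and it is not part of the script above.

import tensorflow as tf

reader = tf.train.NewCheckpointReader('ckpt/model.ckpt')   # reads the checkpoint files directly
var_to_shape = reader.get_variable_to_shape_map()           # dict: variable name -> shape list
for name in sorted(var_to_shape):
    print(name, var_to_shape[name])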
Part of the txt file generated by the script above is shown below (click here for the complete file):
<tf.Variable 'Layer1/conv2d/kernel:0' shape=(7, 7, 3, 64) dtype=float32_ref>
<tf.Variable 'Layer1/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer1/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d/kernel:0' shape=(1, 1, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d_2/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_2/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_2/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/conv2d/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/batch_normalization/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/batch_normalization/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d/kernel:0' shape=(1, 1, 256, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization_1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization_1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d_2/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
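Each line is simply the repr of a tf.Variable. For a conv2d kernel the shape is (kernel_height, kernel_width, in_channels, out_channels), so (7, 7, 3, 64) is a 7x7 convolution mapping a 3-channel input to 64 output channels, while the (64,) gamma/beta entries are the per-channel batch-norm parameters.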
In fact, the printed network structure and variables also let you identify the names of the output nodes, for example:
Score/Head/conv2d_1/BiasAdd
BBox/Head/conv2d_1/BiasAdd
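As mentioned at the start, these output node names are exactly what a pb export needs. A minimal freezing sketch in TF 1.x, assuming the same ckpt/model.ckpt checkpoint and the two node names above (this uses the standard tf.graph_util API and is not code from the original script):

import tensorflow as tf

input_checkpoint = 'ckpt/model.ckpt'
output_pb = 'frozen_model.pb'
output_node_names = ['Score/Head/conv2d_1/BiasAdd', 'BBox/Head/conv2d_1/BiasAdd']

saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
with tf.Session() as sess:
    saver.restore(sess, input_checkpoint)
    # Replace variables with constants so the whole graph fits in a single pb file
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names)
    with tf.gfile.GFile(output_pb, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

If you are not sure which node is the real output, printing all node names with [n.name for n in sess.graph_def.node] before freezing is a quick way to check.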
Author: 墨殇浅尘