代码改变世界

把ResNet-L152模型的ckpt文件转化为pb文件

2018-02-15 17:29  Time皇族  阅读(4165)  评论(0编辑  收藏  举报
import tensorflow as tf
from tensorflow.python.tools import freeze_graph

#os.environ['CUDA_VISIBLE_DEVICES']='2'  #设置GPU
model_path  = "D:\\JupyterWorkSpace\\Tensorflow\\Fine-tuning\\tensorflow-resnet-pretrained-20160509\\ResNet-L152.ckpt" #设置model的路径

def main():
    tf.reset_default_graph()
    saver = tf.train.import_meta_graph("D:\\JupyterWorkSpace\\Tensorflow\\Fine-tuning\\tensorflow-resnet-pretrained-20160509\\ResNet-L152.meta")
    #flow = tf.cast(flow, tf.uint8, 'out') #设置输出类型以及输出的接口名字,为了之后的调用pb的时候使用
    with tf.Session() as sess:
        saver.restore(sess, model_path)
        #保存图
        tf.train.write_graph(sess.graph_def, './ResNet_L152_retrain/pb_model', 'model_ResNet_L152.pb')
        #把图和参数结构一起
        freeze_graph.freeze_graph('ResNet_L152_retrain/pb_model/model_ResNet_L152.pb',
                                  '',
                                  False,
                                  model_path,
                                  'fc/xw_plus_b',
                                  'save/restore_all',
                                  'save/Const:0',
                                  'ResNet_L152_retrain/pb_model/frozen_model_ResNet_L152.pb',
                                  False,
                                  "")
    print("done")
    
if __name__ == '__main__':
    main()

  总共有11个参数,一个个介绍下(必选: 表示必须有值;可选: 表示可以为空): 
1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明) 
2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。 
3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False 
4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。 
5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。 
6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all 
7、filename_tensor_name:(可选)已弃用。默认:save/Const:0 
8、output_graph:(必选)用来保存整合后的模型输出文件。 
9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认) 
10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。 
11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。 

参考:http://blog.csdn.net/yjl9122/article/details/78341689