TensorFlow

官网

Tensorflow源码分析

A、基本概念

  1. Graph

  2. Tensor 

  3. Session

B、Tools

  1. Checkpoint  .Ckpt

  2. Pb

  3. .Ckpt To .Pb

  4. TensorBoard

 B.1  .Ckpt 模型加载

1. 模型的保存

import tensorflow as tf

def store_model_ckpt(ckpt_file_path):
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    #模型的保存必须有变量
    c = tf.Variable(1, name='c')
    a = tf.add(x, y, name='op')
    result = tf.add(a, c)

    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
    
        saver = tf.train.Saver()
    
        #如果只保存其中一部分变量,则使用下面代码,用列表或者字典都可以
        #saver = tf.train.Saver([x, y])
    
        #这里面有参数global_step=50,当训练50步便保存模型
        saver.save(sess, ckpt_file_path)
        # test
        feed_dict = {x: 2, y: 3}
        print(sess.run(result, feed_dict))

def main():
    ckpt_file_path = "./ckpt/model.ckpt"
    store_model_ckpt(ckpt_file_path)

if __name__ == '__main__':
    main()

结果:6

程序生成并保存四个文件

  1. checkpoint 文本文件,记录了模型文件的路径信息列表
  2. model.ckpt.data-00000-of-00001 网络权重信息
  3. model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息
  4. model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf

2. 模型恢复加载

针对上面的模型保存例子,还原模型的过程如下:

import tensorflow as tf

def restore_model_ckpt():
    with tf.Session() as sess:
        #step1:加载模型结构
        saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta')
        #step2:只需要指定目录就可以恢复所有变量信息
        saver.restore(sess,tf.train.latest_checkpoint('./ckpt'))
        
        #直接获取保存的变量
        print(sess.run('c:0'))
        
        #获取placeholder变量,通过get_tensor_by_name
        x = sess.graph.get_tensor_by_name('x:0')
        y = sess.graph.get_tensor_by_name('y:0')
        
        #获取需要进行计算的op算子,此op为加法
        op = sess.graph.get_tensor_by_name('op:0')
        
        #加入新的op操作,新的op为乘法
        new_op = tf.multiply(op, 2)
        
        #test
        feed_dict = {x:2, y:3}
        
        result = sess.run(new_op,feed_dict)
        print(result)

def main():
    restore_model_ckpt()
    
if __name__ == '__main__':
    main()

结果:10

  1. 首先还原模型结构

  2. 然后还原变量(参数)信息

  3. 最后我们就可以获得已训练的模型中的各种信息了(保存的变量、placeholder变量、operator等),同时可以对获取的变量添加各种新的操作(见以上代码注释)。
  并且,我们也可以加载部分模型,在此基础上加入其它操作,具体可以参考官方文档和demo。

  针对ckpt模型文件的保存与还原,stackoverflow上有一个回答解释比较清晰,可以参考。

  同时cv-tricks.com上面的TensorFlow模型保存与恢复的教程也非常好,可以参考。

B. 2 Pb模型文件

 1. pb模型保存

import tensorflow as tf
from tensorflow.python.framework import graph_util

def store_model_pb(pb_file_path):
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    a = tf.add(x, y)
    #该op算子应该加上name
    op = tf.add(a, b, name='op')
    
    with tf.Session() as sess:
        init = tf.initialize_all_variables()
        sess.run(init)
        
        #导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算
        graph_def = tf.get_default_graph().as_graph_def()
        
        #将图中的变量及其取值转化为常量,同时将图中的不必要的节点去掉
        output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['op'])
        
        with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
            f.write(output_graph_def.SerializeToString())
        
        #test
        feed_dict = {x: 2, y: 3}
        print(sess.run(op, feed_dict))

def main():
    pb_file_path = "model.pb"
    store_model_pb(pb_file_path)
    
if __name__ == '__main__':
    main()
    

结果:6 

  在当前文件下面生成model.pb文件

2. pb模型加载

import tensorflow as tf
from tensorflow.python.platform import gfile
    
def restore_model_pb(pb_file_path):
    with tf.Session() as sess:
        with gfile.FastGFile(pb_file_path, 'rb') as f:
            graph_def = tf.GraphDef()
            #转换成字符串形式
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name='')
       
        #获取placeholder的变量
        x = sess.graph.get_tensor_by_name('x:0')
        y = sess.graph.get_tensor_by_name('y:0')
        
        #获取op算子
        op = sess.graph.get_tensor_by_name('op:0')
        
        feed_dict = {x: 2, y:3}
        result = sess.run(op,feed_dict)
        print(result)
          
def main():
    pb_file_path = "model.pb"
    restore_model_pb(pb_file_path)
    
if __name__ == '__main__':
    main()

结果:5

B 3. 将.Ckpt 转换为.Pb

  但很多时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),方便在其他地方使用(如在Android中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。 我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

    TensoFlow为我们提供了convert_variables_to_constants()方法,该方法可以固化模型结构,将计算图中的变量取值以常量的形式保存,而且保存的模型可以移植到Android平台。

一、CKPT 转换成 PB格式

  将CKPT 转换成 PB格式的文件的过程可简述如下:

    1. 通过传入 CKPT 模型的路径得到模型的图和变量数据
    2. 通过 import_meta_graph 导入模型中的图
    3. 通过 saver.restore 从模型中恢复图中各个变量的数据
    4. 通过 graph_util.convert_variables_to_constants 将模型持久化

Code:freeze_graph.py

import tensorflow as tf
from tensorflow.python.framework import graph_util

def freeze_graph(ckpt_file_path, pb_file_path):
    #“input:0”是张量的名称,而"input"表示的是节点的名称。
    #此处输入的应该是节点的名称
    output_node_names = "op"
    #首先恢复图结构
    saver = tf.train.import_meta_graph(ckpt_file_path+'.meta',clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    
    with tf.Session() as sess:
        #恢复图并得到数据
        saver.restore(sess,ckpt_file_path)
        output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,
                #如果有多个输出节点
                output_node_names=output_node_names.split(","))
        with tf.gfile.GFile(pb_file_path,"wb") as f:
            f.write(output_graph_def.SerializeToString())
            print("%d ops in the final graph." % len(output_graph_def.node)) 
                     
def main():
    # 输入ckpt模型路径
    model_folder = "D:\AI\Ckpt\TestCkpt\ckpt"
    #检查目录下ckpt文件状态是否可用
    checkpoint = tf.train.get_checkpoint_state(model_folder) 
    #得ckpt文件路径
    ckpt_file_path = checkpoint.model_checkpoint_path 
    
    # 输出pb模型的路径
    pb_file_path="frozen_model.pb"
    
    # 调用freeze_graph将ckpt转为pb
    freeze_graph(ckpt_file_path,pb_file_path)
    
if __name__ == '__main__':
    main()

结果:生成 frozen_model.pb文件,可以采用上面pb模型加载的方法测试该pb文件

说明:

1、函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。因为网络其实是比较复杂的,定义了输出结点的名字,那么freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。

 2、在保存的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称,对于鄙人的代码,需要固化的节点只有一个:output_node_names。注意节点名称与张量的名称的区别,例如:“input:0”是张量的名称,而"input"表示的是节点的名称。

3、源码中通过graph = tf.get_default_graph()获得默认的图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,因此必须先执行tf.train.import_meta_graph,再执行tf.get_default_graph() 。

4、上面以及说明:在保存的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称,对于鄙人的代码,需要固化的节点只有一个:output_node_names。因此,其他网络模型,也可以通过简单的修改输出的节点名称output_node_names,将ckpt转为pb文件 。

       PS:注意节点名称,应包含name_scope 和 variable_scope命名空间,并用“/”隔开,如"InceptionV3/Logits/SpatialSqueeze"

B.4 TensorBoard

  1. 生成graph

# -*- coding: utf-8 -*-
"""
Created on Sat Dec 22 09:49:04 2018

@author: weilong
"""

import tensorflow as tf

#定义简单的计算图,实现向量加法的操作
with tf.name_scope("imput1"):
    input1 = tf.constant([1.0, 2.0, 3.0], name="input1")
with tf.name_scope("input2"):
    input2 = tf.Variable(tf.random_uniform([3]), name="input2")
output = tf.add_n([input1, input2], name="add")

#生成写日志的writer,并将当前的tensorflow计算图写入日志
writer = tf.summary.FileWriter("./log", tf.get_default_graph())
writer.close()

 2. 将训练好的model.pb文件在tensorboard中展示其网络结构

import tensorflow as tf

model = 'model.pb' #请将这里的pb文件路径改为自己的
graph = tf.get_default_graph()
graph_def = graph.as_graph_def()
graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read())
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('log/', graph)

执行以上代码就会生成文件在log/events.out.tfevents.1535079670.DESKTOP-5IRM000。

 在tensorboard中加载:

tensorboard --logdir=\path\to\log

在浏览器中

拷贝网站链接在浏览器中即可。

参考:https://blog.csdn.net/guyuealian/article/details/82218092

posted @ 2018-12-19 15:17  weilongyitian  阅读(688)  评论(0编辑  收藏  举报