查看tensorflow pb模型文件的节点信息

查看tensorflow pb模型文件的节点信息:

import tensorflow as tf
with tf.Session() as sess:
    with open('./quantized_model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read()) 
        print graph_def
        

效果:

复制代码
# ...
node {
  name: "FullyConnected/BiasAdd"
  op: "BiasAdd"
  input: "FullyConnected/MatMul"
  input: "FullyConnected/b/read"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "data_format"
    value {
      s: "NHWC"
    }
  }
}
node {
  name: "FullyConnected/Softmax"
  op: "Softmax"
  input: "FullyConnected/BiasAdd"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
library {
}
复制代码

 

 

参考:https://tang.su/2017/01/export-TensorFlow-network/

https://github.com/tensorflow/tensorflow/issues/15689

一些核心代码:

import tensorflow as tf
with tf.Session() as sess:
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read()) 
        print graph_def
        output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
        print(sess.run(output))

 

 

This is part of my Tensorflow frozen graph, I have named the input and output nodes.

>>> g.ParseFromString(open('frozen_graph.pb','rb').read())
>>> g
node {
  name: "input"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 68
        }
      }
    }
  }
}
...
node {
  name: "output"
  op: "Softmax"
  input: "add"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}

I ran this model by the following code
(CELL is name of directory where my file is located)

final String MODEL_FILE = "file:///android_asset/" + CELL + "/optimized_graph.pb" ;
final String INPUT_NODE = "input" ;
final String OUTPUT_NODE = "output" ;
final int[] INPUT_SIZE = {1,68} ;
float[] RESULT = new float[8];

inferenceInterface = new TensorFlowInferenceInterface();
inferenceInterface.initializeTensorFlow(getAssets(),MODEL_FILE) ;
inferenceInterface.fillNodeFloat(INPUT_NODE,INPUT_SIZE,input);

and finally

inferenceInterface.readNodeFloat(OUTPUT_NODE,RESULT);
posted @   bonelee  阅读(6645)  评论(0编辑  收藏  举报
编辑推荐:
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
历史上的今天:
2017-02-23 Luke 5—— 可视化 Lucene 索引查看工具,可以查看ES的索引
2017-02-23 Apache Flink vs Apache Spark——感觉二者是互相抄袭啊 看谁的好就抄过来 Flink支持在runtime中的有环数据流,这样表示机器学习算法更有效而且更有效率
2017-02-23 druid相关的时间序列数据库——也用到了倒排相关的优化技术
2017-02-23 时间序列数据库——索引用ES、聚合分析时加载数据用什么?docvalues的列存储貌似更优优势一些。那分布式计算呢?ES做
2017-02-23 时间序列数据库——索引用ES、聚合分析时加载数据用什么?docvalues的列存储貌似更优优势一些
2017-02-23 时间序列数据库概览——基于文件(RRD)、K/V数据库(influxDB)、关系型数据库
2017-02-23 ES索引瘦身 禁用_source后需要设置field store才能获取数据 否则无法显示搜索结果
点击右上角即可分享
微信分享提示