ValueError: Input 0 of node import/save/Assign was passed float from import/beta1_power:0 incompatible with expected float_ref

An exception is raised when importing an optimized frozen graph:

import tensorflow as tf

# read pb into graph_def (pb_file is the path to the frozen .pb file)
with tf.gfile.GFile(pb_file, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# import graph_def into a fresh graph
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def)

The exception is raised on this line:

tf.import_graph_def(graph_def)

 ValueError: Input 0 of node import/save/Assign was passed float from import/beta1_power:0 incompatible with expected float_ref. 
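The node names in the message hint at the cause. Freezing stores former variables as constants, so `beta1_power` (one of the Adam optimizer's variables) presumably arrives as a plain float, while the Saver's `save/Assign` node still expects a reference (float_ref). Such training-only nodes should normally not be in an inference graph. The sketch below is one way to spot them; it works on plain `(name, op)` pairs, which could be built from a loaded GraphDef with `[(n.name, n.op) for n in graph_def.node]`. The function name `find_training_leftovers`, the op set and the name prefixes are assumptions for illustration, not a TensorFlow API.

```python
from typing import Iterable, List, Tuple

# Op types that read or write variables by reference; a cleanly frozen
# inference graph usually should not contain them (assumed, not exhaustive).
REF_OPS = {"Assign", "AssignAdd", "AssignSub", "RefSwitch", "Variable", "VariableV2"}

# Name scopes that typically come from training-only code (assumed).
TRAINING_PREFIXES = ("save/", "gradients/")


def find_training_leftovers(nodes: Iterable[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Return the (name, op) pairs that look like training-time leftovers."""
    suspicious = []
    for name, op in nodes:
        if op in REF_OPS or name.startswith(TRAINING_PREFIXES):
            suspicious.append((name, op))
    return suspicious


if __name__ == "__main__":
    # Hypothetical node list: beta1_power has already been frozen into a Const,
    # but the Saver's save/Assign node was left behind.
    sample = [
        ("input", "Placeholder"),
        ("beta1_power", "Const"),
        ("save/Assign", "Assign"),
        ("conv1/weights", "Const"),
    ]
    for name, op in find_training_leftovers(sample):
        print(name, op)  # prints: save/Assign Assign
```

If the list is non-empty, the graph was probably frozen without stripping the Saver and optimizer nodes.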

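Solution: make sure your pb_file is in the correct format (something like the following), and try passing a value for the `name` argument of import_graph_def() to override the default "import" prefix. The code below uses name='' and also rewrites the ref-typed ops (RefSwitch to Switch, AssignSub to Sub) that the frozen graph still contains: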

import tensorflow as tf

from tensorflow.python.platform import gfile

model_path = "/tmp/frozen/dcgan.pb"

# read graph definition
with gfile.FastGFile(model_path, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# fix nodes: replace ref-typed ops left over from training with value-typed ones
for node in graph_def.node:
    if node.op == 'RefSwitch':
        node.op = 'Switch'
        for index in range(len(node.input)):
            if 'moving_' in node.input[index]:
                # read the variable's value instead of its reference
                node.input[index] = node.input[index] + '/read'
    elif node.op == 'AssignSub':
        node.op = 'Sub'
        # Sub has no use_locking attribute
        if 'use_locking' in node.attr:
            del node.attr['use_locking']

# import into the current default graph; name='' drops the "import/" prefix
tf.import_graph_def(graph_def, name='')
# save the repaired graph in binary and text form
tf.train.write_graph(graph_def, './', 'good_frozen.pb', as_text=False)
tf.train.write_graph(graph_def, './', 'good_frozen.pbtxt', as_text=True)
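Before writing good_frozen.pb, it can help to see which nodes the loop above would actually touch. The sketch below applies the same two rules (RefSwitch becomes Switch with `moving_` inputs redirected to `/read`, AssignSub becomes Sub without `use_locking`) to plain dictionaries and returns a list of planned changes instead of mutating a GraphDef. The dict layout and the name `plan_fixes` are hypothetical stand-ins for NodeDef's `name`, `op`, `input` and `attr` fields; the `moving_` pattern presumably targets batch-norm moving statistics.

```python
from typing import Dict, List


def plan_fixes(nodes: List[Dict]) -> List[str]:
    """Describe the edits the RefSwitch/AssignSub rewrite would make, without applying them."""
    changes = []
    for node in nodes:
        if node["op"] == "RefSwitch":
            changes.append(f"{node['name']}: RefSwitch -> Switch")
            for i, inp in enumerate(node["input"]):
                if "moving_" in inp:
                    changes.append(f"{node['name']}: input[{i}] {inp} -> {inp}/read")
        elif node["op"] == "AssignSub":
            changes.append(f"{node['name']}: AssignSub -> Sub")
            if "use_locking" in node.get("attr", {}):
                changes.append(f"{node['name']}: drop attr use_locking")
    return changes


if __name__ == "__main__":
    # Hypothetical nodes shaped like a batch-norm layer left in a frozen graph.
    nodes = [
        {"name": "bn/cond/Switch_1", "op": "RefSwitch", "input": ["bn/moving_mean", "bn/cond/pred_id"]},
        {"name": "bn/AssignMovingAvg", "op": "AssignSub", "input": ["bn/moving_mean", "bn/sub"],
         "attr": {"use_locking": False}},
        {"name": "logits", "op": "MatMul", "input": ["x", "w"]},
    ]
    for line in plan_fixes(nodes):
        print(line)
```

An empty result means the rewrite loop would leave the graph unchanged, so the error likely has a different cause.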


References:
https://stackoverflow.com/questions/51084768/valueerror-input-0-of-node-incompatible-with-expected-float-ref
https://github.com/onnx/tensorflow-onnx/issues/77

posted @ 2019-08-08 16:33  鲁太师