Keras + Flask 提供接口服务的坑~~~
ValueError: Tensor Tensor(**************) is not an element of this graph.
When you create a `Model`, the session hasn't been restored yet. All placeholders, variables and ops defined in `Model.__init__` are placed in a new graph, which becomes the default graph inside the `with` block. This is the key line:
with tf.Graph().as_default():
This means that this `tf.Graph()` instance is the same object as `tf.get_default_graph()` inside the `with` block, but not before or after it. From this moment on, two different graphs exist.
When you later create a session and restore a graph into it, you can't access the previous instance of tf.Graph() in that session. Here's a short example:
# Create `var` inside a brand-new graph: `var` belongs to `graph`,
# not to the global default graph that exists outside this `with` block.
with tf.Graph().as_default() as graph:
    var = tf.get_variable("var", shape=[3], initializer=tf.zeros_initializer)
This works
# A session bound to `graph` can evaluate tensors/variables created in `graph`.
with tf.Session(graph=graph) as sess:
    # Variables must be initialized before they can be read.
    sess.run(tf.global_variables_initializer())
    print(sess.run(var))  # ok because `sess.graph == graph`
This fails
# import_meta_graph() loads the saved graph into the current default graph,
# and tf.Session() without `graph=` uses that default graph, not `graph`.
saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess, "/tmp/model.ckpt")
    # Expected to FAIL: the tensor's graph differs from the session's graph.
    print(sess.run(var))  # var is from `graph`, not `sess.graph`!
The best way to deal with this is to give names to all nodes (e.g. 'input', 'target', etc.), save the model, and then look up the nodes in the restored graph by name, something like this:
# Restore the saved graph and look up its nodes by name, so every tensor
# used below is guaranteed to come from the session's own graph.
saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess, "/tmp/model.ckpt")
    # Tensor names have the form "<op_name>:<output_index>"; passing a bare
    # op name such as 'input' makes get_tensor_by_name raise ValueError.
    input_data = sess.graph.get_tensor_by_name('input:0')
    target = sess.graph.get_tensor_by_name('target:0')
This method guarantees that all nodes come from the session's graph.
Try to start with:
import tensorflow as tf
# `global` only has an effect inside a function (e.g. a model-loading helper);
# at module level it is a no-op.
global graph,model
# Capture the current default graph -- presumably the one the Keras model is
# loaded into (confirm) -- so it can be re-entered later before predict().
graph = tf.get_default_graph()
When you need to use predict:
# Re-enter the graph captured at load time so the model's tensors belong to
# the current default graph while predicting.
with graph.as_default():
    y = model.predict(X)
def get_angiogram_time(video_path):
    """Run both OCR models on ``video_path`` and return the result as JSON.

    "Before" version of the fix: the two models are built and loaded lazily on
    the first call and cached in module globals, but ``predict`` is called
    without re-entering any captured graph. According to the article, this is
    what leads to "Cannot use the given session to evaluate tensor: the
    tensor's graph is different from the session's graph" in the Flask service.
    """
    start = time.time()
    # _GRAPH_MA / _GRAPH_TIME are declared here but never assigned in this version.
    global _MODEL_MA,_MODEL_TIME,_GRAPH_MA, _GRAPH_TIME
    if _MODEL_MA == None:  # first call: nothing cached yet (NOTE(review): prefer `is None`)
        model_ma = ma_ocr.Training_Predict()
        model_time = time_ocr.Training_Predict()

        model_ma.build_model()
        model_time.build_model()

        model_ma.load_model("./model/ma_gur_ctc_model.h5base")
        model_time.load_model("./model/time_gur_ctc_model.h5base")

        _MODEL_MA = model_ma
        _MODEL_TIME = model_time

    # The MA model's output (`indexes`) is passed on to the time model.
    indexes = _MODEL_MA.predict(video_path)
    time_dict = _MODEL_TIME.predict(video_path,indexes)
    end = time.time()
    print("耗时:%.2f s" % (end-start))  # "耗时" = elapsed time
    return json.dumps(time_dict)
def predict(self, video_path):  # Run base_model on frames loaded from video_path.
    start = time.time()
    # Open the video; fail fast if OpenCV cannot read it.
    vid = cv2.VideoCapture(video_path)
    if not vid.isOpened():
        raise IOError("Couldn't open webcam or video")
    # video_fps = vid.get(cv2.CAP_PROP_FPS)
    # The base_model.predict call on line 10 is where the graph/session error was raised.
    X = self.load_video_data(vid)
    y_pred = self.base_model.predict(X)
    shape = y_pred[:, :, :].shape  # 2:  (NOTE(review): three indexed axes -> y_pred assumed rank 3)
    out = K.get_value(K.ctc_decode(y_pred[:, :, :], input_length=np.ones(shape[0]) * shape[1])[0][0])[:,  # every input treated as full length shape[1]
          :seq_len]  # 2:  NOTE(review): seq_len is not defined in this snippet
    print()  # NOTE(review): snippet appears truncated -- `out`/`start` unused, nothing returned
当执行到第10行 `y_pred = self.base_model.predict(X)` 时,
就会抛错:Cannot use the given session to evaluate tensor: the tensor's graph is different from the session's graph.
def get_angiogram_time(video_path):
    """Run both OCR models on ``video_path`` and return the result as JSON.

    The two models are built and loaded lazily on the first call and cached
    in module-level globals, together with the default graph captured right
    after loading. Every later ``predict`` runs inside that captured graph,
    so calls made from another context (e.g. a Flask request handler) still
    use the graph the models were built in.

    Args:
        video_path: path of the video file passed to both models.

    Returns:
        ``json.dumps`` of whatever ``_MODEL_TIME.predict`` returns
        (named ``time_dict`` here).
    """
    global _MODEL_MA, _MODEL_TIME, _GRAPH_MA, _GRAPH_TIME
    start = time.time()
    if _MODEL_MA is None:
        model_ma = ma_ocr.Training_Predict()
        model_time = time_ocr.Training_Predict()

        model_ma.build_model()
        model_time.build_model()

        model_ma.load_model("./model/ma_gur_ctc_model.h5base")
        model_time.load_model("./model/time_gur_ctc_model.h5base")

        # Both models were built in the same default graph, so one capture
        # serves both; both globals are kept for any code that reads them.
        graph = tf.get_default_graph()
        _GRAPH_MA = graph
        _GRAPH_TIME = graph
        _MODEL_TIME = model_time
        # _MODEL_MA is the "already initialized" sentinel: assign it last so a
        # partially completed init never leaves the graphs unset.
        _MODEL_MA = model_ma

    with _GRAPH_MA.as_default():
        indexes = _MODEL_MA.predict(video_path)
    # The MA model's output (`indexes`) is passed on to the time model.
    with _GRAPH_TIME.as_default():
        time_dict = _MODEL_TIME.predict(video_path, indexes)
    end = time.time()
    print("耗时:%.2f s" % (end - start))  # "耗时" = elapsed time
    return json.dumps(time_dict)
PS:问了一下专门做AI的朋友,他们公司是用TensorFlow Serving提供对外服务的,我最近也要研究一下TensorFlow Serving,本人是个AI小白,刚刚入门,写的不对还请指正,谢谢!