Keras + Flask: pitfalls of serving a prediction API~~~

I've been working with Keras recently, and after training a model I needed to expose it as a prediction service. My first idea was to stand up an HTTP service with Flask. It did run, but the model was loaded from disk on every prediction request, which was terribly slow.

So I loaded the model into a global variable once and grabbed it whenever a prediction came in. But every time I actually used it, an error was thrown.

For example:

ValueError: Tensor Tensor(**************) is not an element of this graph.

The problem is that the graph in effect when you run the prediction is not the graph that was the default when the model was first initialized, so that graph doesn't contain the model's parameters and nodes at all.
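
The Keras error boils down to exactly this graph mismatch. A minimal sketch in plain TensorFlow 1.x that reproduces the same message (the tensor name is just for illustration):

import tensorflow as tf

# The tensor is created while this graph is the default ...
with tf.Graph().as_default() as g1:
    var = tf.constant([1.0, 2.0, 3.0], name='var')

# ... but this session is bound to a different graph, so the fetch fails:
with tf.Session(graph=tf.Graph()) as sess:
    sess.run(var)  # ValueError: ... Tensor ... is not an element of this graph.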

I found a reliable fix online and verified it myself. Original post: https://wolfx.cn/flask-keras-server/


The solution goes like this:

When you create a Model, the session hasn't been restored yet. All placeholders, variables and ops defined in Model.__init__ are placed in a new graph, which makes itself the default graph inside the with block. This is the key line:

with tf.Graph().as_default():
  ...

This means that this instance of tf.Graph() is the same as the tf.get_default_graph() instance inside the with block, but not before or after it. From this moment on, there exist two different graphs.

When you later create a session and restore a graph into it, you can't access the previous instance of tf.Graph() in that session. Here's a short example:

with tf.Graph().as_default() as graph:
  var = tf.get_variable("var", shape=[3], initializer=tf.zeros_initializer)

This works

with tf.Session(graph=graph) as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(var))  # ok because `sess.graph == graph`

This fails

saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
with tf.Session() as sess:
  saver.restore(sess, "/tmp/model.ckpt")
  print(sess.run(var))   # var is from `graph`, not `sess.graph`!

The best way to deal with this is to give names to all nodes, e.g. 'input', 'target', etc., save the model, and then look up the nodes in the restored graph by name, something like this:

saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
with tf.Session() as sess:
  saver.restore(sess, "/tmp/model.ckpt")
  input_data = sess.graph.get_tensor_by_name('input:0')
  target = sess.graph.get_tensor_by_name('target:0')

This method guarantees that all nodes will come from the graph that is in the session.
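
For completeness, a small sketch of the save side, assuming plain TensorFlow 1.x; the node names 'input' and 'target' are just placeholders chosen to match the lookup above. Note that get_tensor_by_name() wants the full tensor name, i.e. the op name plus an output index such as 'input:0'.

import tensorflow as tf

# Name the nodes you will need later, then save; import_meta_graph will
# restore them under the same names.
x = tf.placeholder(tf.float32, shape=[None, 3], name='input')
w = tf.get_variable('w', shape=[3, 1], initializer=tf.zeros_initializer())
y = tf.matmul(x, w, name='target')

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, '/tmp/model.ckpt')  # also writes /tmp/model.ckpt.meta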

Try to start with:

import tensorflow as tf
global graph,model
graph = tf.get_default_graph()

When you need to use predict:

with graph.as_default():
    y = model.predict(X)
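
Applied to a Flask service, the whole pattern looks roughly like the sketch below. This is not the author's code; the model path, route and request format are assumptions.

import numpy as np
import tensorflow as tf
from flask import Flask, jsonify, request
from keras.models import load_model

app = Flask(__name__)

# Load the model once at startup and remember the graph it was created in.
model = load_model('./model/my_model.h5')  # path is an assumption
graph = tf.get_default_graph()

@app.route('/predict', methods=['POST'])
def predict():
    X = np.array(request.json['data'])
    # Re-enter the model's graph before predicting; the request may be
    # handled in a thread whose default graph is a different one.
    with graph.as_default():
        y = model.predict(X)
    return jsonify(y.tolist())

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)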

------------------------------------------------------------ the fancy divider ------------------------------------------------------------


Now for my own code, and how I applied the fix:

My original code looked like this:

 1 def get_angiogram_time(video_path):
 2     start = time.time()
 3     global _MODEL_MA,_MODEL_TIME,_GRAPH_MA, _GRAPH_TIME
 4     if _MODEL_MA == None:
 5         model_ma = ma_ocr.Training_Predict()
 6         model_time = time_ocr.Training_Predict()
 7 
 8         model_ma.build_model()
 9         model_time.build_model()
10 
11         model_ma.load_model("./model/ma_gur_ctc_model.h5base")
12         model_time.load_model("./model/time_gur_ctc_model.h5base")
13 
14         _MODEL_MA = model_ma
15         _MODEL_TIME = model_time
16 
17     indexes = _MODEL_MA.predict(video_path)
18     time_dict = _MODEL_TIME.predict(video_path,indexes)
19     end = time.time()
20     print("耗时:%.2f s" % (end-start))
21     return json.dumps(time_dict)

And here is the predict method inside Training_Predict, where the error is actually thrown:

 1     def predict(self, video_path):
 2         start = time.time()
 3 
 4         vid = cv2.VideoCapture(video_path)
 5         if not vid.isOpened():
 6             raise IOError("Couldn't open webcam or video")
 7         # video_fps = vid.get(cv2.CAP_PROP_FPS)
 8 
 9         X = self.load_video_data(vid)
10         y_pred = self.base_model.predict(X)
11         shape = y_pred[:, :, :].shape  # 2:
12         out = K.get_value(K.ctc_decode(y_pred[:, :, :], input_length=np.ones(shape[0]) * shape[1])[0][0])[:,
13               :seq_len]  # 2:
14         print()


When execution reaches line 10, y_pred = self.base_model.predict(X), it throws:

Cannot use the given session to evaluate tensor: the tensor's graph is different from the session's graph.


Roughly speaking: the graph held by the current session doesn't match the graph the model was built in, so none of the model's parameters can be found there.
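
The same message can be reproduced outside Keras with a tiny sketch (plain TensorFlow 1.x), which makes the mismatch explicit:

import tensorflow as tf

# The tensor lives in one graph ...
with tf.Graph().as_default():
    x = tf.constant(1.0)

# ... while the session is bound to another graph (the global default graph):
sess = tf.Session()
x.eval(session=sess)
# ValueError: Cannot use the given session to evaluate tensor:
# the tensor's graph is different from the session's graph.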


The code after the fix:


 1 def get_angiogram_time(video_path):
 2     start = time.time()
 3     global _MODEL_MA,_MODEL_TIME,_GRAPH_MA, _GRAPH_TIME
 4     if _MODEL_MA == None:
 5         model_ma = ma_ocr.Training_Predict()
 6         model_time = time_ocr.Training_Predict()
 7 
 8         model_ma.build_model()
 9         model_time.build_model()
10 
11         model_ma.load_model("./model/ma_gur_ctc_model.h5base")
12         model_time.load_model("./model/time_gur_ctc_model.h5base")
13 
14         _MODEL_MA = model_ma
15         _MODEL_TIME = model_time
16         _GRAPH_MA = tf.get_default_graph()
17         _GRAPH_TIME = tf.get_default_graph()
18 
19     with _GRAPH_MA.as_default():
20         indexes = _MODEL_MA.predict(video_path)
21     with _GRAPH_TIME.as_default():
22         time_dict = _MODEL_TIME.predict(video_path,indexes)
23     end = time.time()
24     print("耗时:%.2f s" % (end-start))
25     return json.dumps(time_dict)

The main changes are on lines 16, 17, 19 and 21.

A global graph is captured once when the models are first loaded, and every later prediction re-enters that same graph.
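
For context, the Flask side that exposes this function might look roughly like the sketch below; the route name and request field are assumptions, not part of the original code, and get_angiogram_time is assumed to live in (or be imported into) the same module.

from flask import Flask, request

app = Flask(__name__)

@app.route('/angiogram_time', methods=['POST'])
def angiogram_time():
    # get_angiogram_time() lazily builds the models and graphs on the first
    # call, then reuses the global instances for every later request.
    video_path = request.form['video_path']
    return get_angiogram_time(video_path)  # already a JSON string

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)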


Problem solved~


PS: I asked a friend who works in AI full time; his company exposes their models with TensorFlow Serving, so that's what I plan to look into next. I'm a complete beginner in AI and only just getting started, so if anything here is wrong please point it out. Thanks!
