Tensorflow同时加载使用多个模型
在Tensorflow中,所有操作对象都包装到相应的Session中的,所以想要使用不同的模型就需要将这些模型加载到不同的Session中并在使用的时候申明是哪个Session,从而避免由于Session和想使用的模型不匹配导致的错误。而使用多个graph,就需要为每个graph使用不同的Session,但是每个graph也可以在多个Session中使用,这个时候就需要在每个Session使用的时候明确申明使用的graph。
g1 = tf.Graph() # 加载到Session 1的graph g2 = tf.Graph() # 加载到Session 2的graph sess1 = tf.Session(graph=g1) # Session1 sess2 = tf.Session(graph=g2) # Session2 # 加载第一个模型 with sess1.as_default(): with g1.as_default(): tf.global_variables_initializer().run() model_saver = tf.train.Saver(tf.global_variables()) model_ckpt = tf.train.get_checkpoint_state(“model1/save/path”) model_saver.restore(sess, model_ckpt.model_checkpoint_path) # 加载第二个模型 with sess2.as_default(): # 1 with g2.as_default(): tf.global_variables_initializer().run() model_saver = tf.train.Saver(tf.global_variables()) model_ckpt = tf.train.get_checkpoint_state(“model2/save/path”) model_saver.restore(sess, model_ckpt.model_checkpoint_path) ... # 使用的时候 with sess1.as_default(): with sess1.graph.as_default(): # 2 ... with sess2.as_default(): with sess2.graph.as_default(): ... # 关闭sess sess1.close() sess2.close()
注:1、在1处使用as_default使session在离开的时候并不关闭,在后面可以继续使用知道手动关闭;2、由于有多个graph,所以sess.graph与tf.get_default_value的值是不相等的,因此在进入sess的时候必须sess.graph.as_default()明确申明sess.graph为当前默认graph,否则就会报错。
PS:不同框架的模型(tf, caffe, torch等)在加载的很有可能导致底层的cuDNN分配出现问题从而报错,这种一般可以尝试通过模型的加载顺序来解决。
TensorFlow函数:tf.Session()和tf.Session().as_default()的区别
tf.Session().as_default():创建一个默认会话
那么问题来了,会话和默认会话有什么区别呢?TensorFlow会自动生成一个默认的计算图,如果没有特殊指定,运算会自动加入这个计算图中。TensorFlow中的会话也有类似的机制,但是TensorFlow不会自动生成默认的会话,而是需要手动指定。
tf.Session()创建一个会话,当上下文管理器退出时会话关闭和资源释放自动完成。
tf.Session().as_default()创建一个默认会话,当上下文管理器退出时会话没有关闭,还可以通过调用会话进行run()和eval()操作,代码示例如下:
tf.Session()代码示例: import tensorflow as tf a = tf.constant(1.0) b = tf.constant(2.0) with tf.Session() as sess: print(a.eval()) print(b.eval(session=sess))
运行结果如下: 1.0 RuntimeError: Attempted to use a closed Session.
在打印张量b的值时报错,报错为尝试使用一个已经关闭的会话。使用 tf.Session().as_default()不会有这个问题。
对于run()方法也是一样,如果想让默认会话在退出上下文管理器时关闭会话,可以调用sess.close()方法。
import tensorflow as tf a = tf.constant(1.0) b = tf.constant(2.0) with tf.Session().as_default() as sess: print(a.eval()) sess.close() print(b.eval(session=sess))
1.0
RuntimeError: Attempted to use a closed Session.
参考:
https://www.tensorflow.org/api_docs/python/tf/Session
https://stackoverflow.com/questions/41607144/loading-two-models-from-saver-in-the-same-tensorflow-session
https://www.cnblogs.com/arkenstone/p/7016481.html