TensorFlow加载部分模型
详情参考https://www.cnblogs.com/yibeimingyue/p/11921474.html
本文采用的方式为重写一样的graph, 然后恢复指定scope
1.保存模型部分,通过saver参数,定义要保存的scope:
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
tf.get_collection能生成参与训练的scope值的列表,可以直接作为Saver的参数,可以按需通过列表切片指定scope.
variables = variables[8:]
saver = tf.train.Saver(variables) # create a saver
saver.save(sess,saver_path)
2.重写一样的graph
class Rebuild(object): def __init__(self, batch_size): """ build the graph """
定义两个saver
saver_vgg = tf.train.Saver(vgg_ref_vars) # 这个是要恢复部分的saver saver = tf.train.Saver() # 这个是当前新图的saver
在实例化Rebuild类之后,初始化,然后restore
with tf.Session(config=config) as sess: sess.run(init) ... saver_vgg.restore(sess, vgg_graph_weight)#使用导入图的saver来恢复