TensorFlow加载部分模型

详情参考https://www.cnblogs.com/yibeimingyue/p/11921474.html

本文采用的方式为重写一样的graph, 然后恢复指定scope

1.保存模型部分,通过saver参数,定义要保存的scope:

variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
  tf.get_collection能生成参与训练的scope值的列表,可以直接作为Saver的参数,可以按需通过列表切片指定scope.
variables = variables[8:]
 saver = tf.train.Saver(variables)  # create a saver
saver.save(sess,saver_path)

  2.重写一样的graph

class Rebuild(object):
    def __init__(self, batch_size):
        """
        build the graph
        """

  定义两个saver

saver_vgg = tf.train.Saver(vgg_ref_vars) # 这个是要恢复部分的saver
saver = tf.train.Saver() # 这个是当前新图的saver

  在实例化Rebuild类之后,初始化,然后restore

with tf.Session(config=config) as sess:
sess.run(init)
...
saver_vgg.restore(sess, vgg_graph_weight)#使用导入图的saver来恢复

  


  

 

posted @ 2021-12-30 16:14  ming_z  阅读(122)  评论(0编辑  收藏  举报