保存

 1 W = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name="weights")    #2行3列
 2 b = tf.Variable([[1,2,3],dtype=tf.float32,name="biases")
 3 #restore时,shape和dtype要一样才可以正确导入
 4 
 5 init = tf.initialize_all_variables()
 6 
 7 saver= tf.train.Saver()    #saver保存
 8 
 9 with tf.Session() as sess:
10     sess.run(init)
11     save_path = saver.save(sess,"xxx/xxxxx.ckpt")    #ckpt后缀
12     print("Save to path:",save_path)

 

导入

 1 #定义相同shape和dtype的变量
 2 W = tf.Variable(np.arange(6).reshpe((2,3)),dtype=tf.float32,name="weights")
 3 b = tf.Variable(np.arange(3).reshpe((1,3)),dtype=tf.float32,name="biases")
 4 
 5 #not need to init step
 6 saver = tf.train.Saver()
 7 with tf.Session as sess:
 8     saver.restore(sess,"xxx/xxxxx.ckpt")
 9     print("weights",sess.run(W))
10     print("biases",sess.run(b))

 

结果（恢复后打印出的正是保存时的值）：

weights [[1. 2. 3.]
 [3. 4. 5.]]
biases [[1. 2. 3.]]

 

posted on 2022-08-10 11:07  Jolyne123  阅读(13)  评论(0)  编辑  收藏  举报