保存
# Save two variables to a TensorFlow 1.x checkpoint.
# The restore side must declare variables with the same names, shapes and
# dtypes: the default tf.train.Saver looks saved values up by variable name,
# and a shape/dtype mismatch makes the restore fail.
import os

# 2 rows x 3 columns
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name="weights")
# 1 row x 3 columns (the original was missing the closing bracket here)
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name="biases")

# tf.initialize_all_variables() is deprecated; this is its replacement.
init = tf.global_variables_initializer()

# With no var_list, the Saver handles every global variable (W and b here).
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    ckpt_path = "xxx/xxxxx.ckpt"  # "xxx/" is a placeholder directory
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        # Saver.save will not create a missing parent directory; it raises
        # ValueError instead, so make sure it exists first.
        os.makedirs(ckpt_dir, exist_ok=True)
    # ".ckpt" is the conventional checkpoint suffix; save() returns the
    # path prefix it actually wrote.
    save_path = saver.save(sess, ckpt_path)
    print("Save to path:", save_path)
导入
# Restore the variables written by the save step from the same checkpoint.
# Start from an empty default graph: if the save step already ran in this
# process, its "weights"/"biases" variables would still exist, the new ones
# below would be renamed "weights_1"/"biases_1", and restore() would not find
# those names in the checkpoint.
tf.reset_default_graph()

# Declare variables with the same names, shapes and dtypes as when saving.
# The initial values are only placeholders; restore() overwrites them.
# (The original called the misspelled ``reshpe``, which does not exist.)
W = tf.Variable(
    np.arange(6, dtype=np.float32).reshape((2, 3)),
    dtype=tf.float32,
    name="weights",
)
b = tf.Variable(
    np.arange(3, dtype=np.float32).reshape((1, 3)),
    dtype=tf.float32,
    name="biases",
)

# No initializer step is needed: saver.restore() assigns the saved values
# to every variable this Saver manages.
saver = tf.train.Saver()
# tf.Session must be instantiated; the class object itself is not a
# context manager (the original omitted the parentheses).
with tf.Session() as sess:
    saver.restore(sess, "xxx/xxxxx.ckpt")
    print("weights", sess.run(W))
    print("biases", sess.run(b))
结果