解决tensorflow模型保存时Saver报错:TypeError: TF_SessionRun_wrapper: expected all values in input dict to be ndarray
TypeError: TF_SessionRun_wrapper: expected all values in input dict to be ndarray
import os

# Set the log level before importing TensorFlow so it is already in place
# when the native library loads ('2' hides INFO and WARNING messages).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # noqa: E402  (must follow the environment setup)


def myregression():
    """Train a one-variable linear regression and save it as a checkpoint.

    Builds a TensorFlow 1.x graph that fits ``y = w * x + b`` to synthetic
    data generated from ``y = 0.7 * x + 0.8``, trains it with gradient
    descent until the loss drops to 1e-6 or below, writes loss/weight
    summaries under ``./tmp/summary/test/`` and saves the trained variables
    to the checkpoint prefix ``tmp/ckpt/model``.

    Returns:
        tuple: ``(weight, bias)`` -- the ``tf.Variable`` objects of the model.
        The session that holds their trained values is closed on return, so
        restore the checkpoint to read the values back.
    """
    with tf.variable_scope("data"):
        # Synthetic inputs: 100 samples with a single feature each.
        x = tf.random_normal([100, 1], mean=1.75, stddev=0.5)
        # Ground truth follows y = 0.7 * x + 0.8 with no added noise.
        y_true = tf.matmul(x, [[0.7]]) + 0.8

    with tf.variable_scope("model"):
        # Weight matrix of shape (n_features, 1), randomly initialised.
        weight = tf.Variable(
            tf.random_normal([int(x.shape[1]), 1], mean=0, stddev=1),
            name="w",
        )
        # Bias term.
        bias = tf.Variable(0.0, name='b')
        # Model prediction.
        y_predict = tf.matmul(x, weight) + bias

    with tf.variable_scope("loss"):
        # Mean squared error between target and prediction.
        loss = tf.reduce_mean(tf.square(y_true - y_predict))

    with tf.variable_scope("optimizer"):
        # Plain gradient descent with learning rate 0.1.
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    # 1. Collect the tensors to be shown in TensorBoard.
    tf.summary.scalar("losses", loss)
    tf.summary.histogram("weights", weight)
    # 2. Single op that evaluates every collected summary.
    merged = tf.summary.merge_all()

    # Op that saves the model variables.
    saver = tf.train.Saver()
    ckpt_prefix = "tmp/ckpt/model"

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # import matplotlib.pyplot as plt
        # plt.scatter(x.eval(), y_true.eval())
        # plt.show()
        print("初始化的权重:%f,偏置项:%f" % (weight.eval(), bias.eval()))

        # Event file for TensorBoard.
        filewriter = tf.summary.FileWriter('./tmp/summary/test/',
                                           graph=sess.graph)
        try:
            n = 0
            # NOTE: `x` is a random op, so every eval()/run() below works on
            # a freshly sampled batch.
            while loss.eval() > 1e-6:
                n += 1
                sess.run(train_op)
                summary = sess.run(merged)
                filewriter.add_summary(summary, n)
                print("第%d次权重:%f,偏置项:%f"
                      % (n, weight.eval(), bias.eval()))
        finally:
            # Flush pending events and release the file even when training
            # is interrupted; the original never closed the writer.
            filewriter.close()

        # Saver.save() expects the parent directory of the checkpoint prefix
        # to exist, so create it up front.
        os.makedirs(os.path.dirname(ckpt_prefix), exist_ok=True)
        saver.save(sess, ckpt_prefix)

    return weight, bias


weight, bias = myregression()
# x_min,x_max = np.min(x.eval()),np.max(x.eval())
# tx = np.arange(x_min,x_max,100)
I ran into the same issue. I don't think it is directly an issue with TensorFlow. In my case I had not changed anything in
TensorFlow, but I had installed some other packages which reinstalled, among other things, numpy. The following fixed the issue for me:
pip uninstall numpy # Keep repeating until all versions of numpy are uninstalled
pip install numpy
pip uninstall numpy
pip install numpy
(假如上一步报错)pip install -U numpy