Tensorflow机器学习入门——读取数据
TensorFlow 中可以通过三种方式读取数据:
一、通过feed_dict传递数据;
input1 = tf.placeholder(tf.float32) input2 = tf.placeholder(tf.float32) output = tf.multiply(input1, input2) with tf.Session() as sess: feed_dict={input1: [[7.,2.]], input2: [[2.],[3.]]} print(sess.run(output,feed_dict ))
二、从文件中读取数据;
import os import tensorflow as tf filename = ['A.jpg', 'B.jpg', 'C.jpg'] # string_input_producer会产生一个文件名队列 filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5) # reader从文件名队列中读数据。对应的方法是reader.read reader = tf.WholeFileReader() key, value = reader.read(filename_queue) init=tf.local_variables_initializer() # tf.train.string_input_producer定义了一个epoch变量,要对它进行初始化 with tf.Session() as sess: sess.run(init) # 使用start_queue_runners之后,才会开始填充队列 tf.train.start_queue_runners(sess=sess) i = 0 while True: i += 1 # 获取图片数据并保存 image_data = sess.run(value) with open('read/test_%d.jpg' % i, 'wb') as f: f.write(image_data) # 程序最后会抛出一个OutOfRangeError,这是epoch跑完,队列关闭的标志
运行上面的代码需要做两点准备:
1.在python的工作目录下保存3张图片,分布命名为:'A.jpg', 'B.jpg', 'C.jpg'
2.在此目录下建立read文件夹
三、使用预加载的数据;