tensorflow学习--数据加载
文章主要来自Tensorflow官方文档,同时加入了自己的理解以及部分代码
数据读取
TensorFlow程序读取数据一共有3种方法:
- 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
- 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
- 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。
目录
数据读取
- 供给数据(Feeding)
- 从文件读取数据
- 文件名, 乱序(shuffling), 和最大训练迭代数(epoch limits)
- 文件格式
- 预处理
- 批处理
- 使用QueueRunner创建预读线程
- 对记录进行过滤或者为每个纪录创建多个样本
- 序列化输入数据(Sparse input data)
- 预加载数据
- 多管线输入
供给数据
TensorFlow的数据供给机制允许你在TensorFlow运算图中将数据注入到任一张量中。因此,python运算可以把数据直接设置到TensorFlow图中。通过给run()或者eval()函数输入feed_dict参数, 可以启动运算过程。
with tf.Session():
input = tf.placeholder(tf.float32)
classifier = ...
print classifier.eval(feed_dict={input: my_python_preprocessing_fn()})
虽然你可以使用常量和变量来替换任何一个张量, 但是最好的做法应该是使用placeholder op节点。设计placeholder节点的唯一的意图就是为了提供数据供给(feeding)的方法。placeholder节点被声明的时候是未初始化的, 也不包含数据, 如果没有为它供给数据, 则TensorFlow运算的时候会产生错误, 所以千万不要忘了为placeholder提供数据。可以在tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py找到使用placeholder和MNIST训练的例子,MNIST tutorial也讲述了这一例子。
从文件读取数据
一共典型的文件读取管线会包含下面这些步骤:
- 文件名列表
- 可配置的 文件名乱序(shuffling)
- 可配置的 最大训练迭代数(epoch limit)
- 文件名队列
- 针对输入文件格式的阅读器
- 纪录解析器
- 可配置的 预处理器
- 样本队列
文件名, 乱序(shuffling), 和最大训练迭代数(epoch limits)
可以使用字符串张量(比如["file0", "file1"], [("file%d" % i) for i in range(2)], [("file%d" % i) for i in range(2)]) 或者tf.train.match_filenames_once 函数来产生文件名列表。
将文件名列表交给tf.train.string_input_producer 函数.string_input_producer来生成一个先入先出的队列, 文件阅读器会需要它来读取数据。
tf.train.slice_input_producer定义了样本放入文件名队列的方式,包括迭代次数,是否乱序等,要真正将文件放入文件名队列,还需要调用tf.train.start_queue_runners 函数来启动执行文件名队列填充的线程,之后计算单元才可以把数据读出来,否则文件名队列为空的,计算单元就会处于一直等待状态,导致系统阻塞。
这个QueueRunner的工作线程是独立于文件阅读器的线程, 因此乱序和将文件名推入到文件名队列这些过程不会阻塞文件阅读器运行。
文件格式
根据你的文件格式,选择对应的文件阅读器,然后将文件名队列提供给阅读器的read方法。阅读器的read方法会输出一个key来表征输入的文件和其中的纪录(对于调试非常有用),同时得到一个字符串标量, 这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。
from __future__ import absolute_import, division, print_function
import tensorflow as tf
import os
os.environ['CUDA_VISIBLE_DEVICES']='3'
# Download Titanic dataset (in csv format).
filename_queue = tf.train.string_input_producer(["iris.csv"])
# Skip 1 line from the beginning of every file if needed
reader = tf.TextLineReader(skip_header_lines=1)
key, value = reader.read(filename_queue)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1.0], [1.0], [1.0], [1.0], ["1"]]
col1, col2, col3, col4, col5, col6 = tf.decode_csv(
value, record_defaults=record_defaults)
features = [col1, col2, col3, col4, col5]
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(1200):
# Retrieve a single instance:
example, label = sess.run([features, col6])
print(label)
coord.request_stop()
coord.join(threads)
每次read的执行都会从文件中读取一行内容, decode_csv 操作会解析这一行内容并将其转为张量列表。如果输入的参数有缺失,record_default参数可以根据张量的类型来设置默认值。
在调用run或者eval去执行read之前, 你必须调用tf.train.start_queue_runners来将文件名填充到队列。否则read操作会被阻塞到文件名队列中有值为止。
固定长度的记录
从二进制文件中读取固定长度纪录, 可以使用tf.FixedLengthRecordReader的tf.decode_raw操作。decode_raw操作可以讲一个字符串转换为一个uint8的张量。
举例来说,the CIFAR-10 dataset的文件格式定义是:每条记录的长度都是固定的,一个字节的标签,后面是3072字节的图像数据。uint8的张量的标准操作就可以从中获取图像片并且根据需要进行重组。 例子代码可以在tensorflow/models/image/cifar10/cifar10_input.py
找到,具体讲述可参见教程.
标准Tensorflow格式
另一种保存记录的方法可以允许你将任意的数据转换为TensorFlow所支持的格式, 这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter class写入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py
就是这样的一个例子。
从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个parse_single_example操作可以将Example协议内存块(protocol buffer)解析为张量。 MNIST的例子就使用了convert_to_records 所构建的数据。 请参看tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py
, 你也可以将这个例子跟fully_connected_feed的版本加以比较。接下来用下面的例子解释如何构建tfrecord文件并从tfrecord文件中读取数据
filename_queue = tf.train.string_input_producer(["/home/learning/tensorflow/iris.csv"])
# Create TFRecords
# Generate Integer Features.
def build_int64_feature(data):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[data]))
# Generate Float Features.
def build_float_feature(data):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[data]))
# Generate String Features.
def build_string_feature(data):
"""Returns a bytes_list from a string / byte."""
if isinstance(data, type(tf.constant(0))):
data = data.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data]))
# Generate a TF `Example` parsing all features of the dataset
def convert_to_tfexample(no, sepal_length, sepal_width, petal_length, petal_width, species):
return tf.train.Example(
features=tf.train.Features(
feature={
'no': build_int64_feature(no),
'sepal_length': build_float_feature(sepal_length),
'sepal_width': build_float_feature(sepal_width),
'petal_length': build_float_feature(petal_length),
'petal_width': build_float_feature(petal_width),
'species': build_string_feature(species),
}
)
)
with open('/home/learning/tensorflow/iris.csv') as f:
with tf.python_io.TFRecordWriter('/home/learning/tensorflow/iris.tfrecord') as w:
# Generate a TF Example for all row in our dataset.
# CSV reader will read and parse all rows.
reader = csv.reader(f, skipinitialspace=True)
for i, record in enumerate(reader):
if i == 0:
continue
no, sepal_length, sepal_width, petal_length, petal_width, species = record
species = species.encode('utf-8')
# Parse each csv row to TF Example using the above functions.
example = convert_to_tfexample(int(no), float(sepal_length), float(sepal_width), float(petal_length),
float(petal_width), species)
# Serialize each TF Example to string, and write to TFRecord file
w.write(example.SerializeToString())
# Build features template, with types.
features = {
'no': tf.FixedLenFeature([], tf.int64),
'sepal_length': tf.FixedLenFeature([], tf.float32),
'sepal_width': tf.FixedLenFeature([], tf.float32),
'petal_length': tf.FixedLenFeature([], tf.float32),
'petal_width': tf.FixedLenFeature([], tf.float32),
'species': tf.FixedLenFeature([], tf.string),
}
# Create TensorFlow session.
sess = tf.Session()
# Load TFRecord data.
filenames = ["/home/zhangyiran/learning/tensorflow/iris.tfrecord"]
data = tf.data.TFRecordDataset(filenames)
# Parse features, using the above template.
def parse_record(record):
return tf.parse_single_example(record, features=features)
# Apply the parsing to each record from the dataset.
data = data.map(parse_record)
# Refill data indefinitely.
data = data.repeat()
# Shuffle data.
data = data.shuffle(buffer_size=1000)
# Batch data (aggregate records together).
data = data.batch(batch_size=4)
# Prefetch batch (pre-load batch for faster consumption).
data = data.prefetch(buffer_size=1)
# Create an iterator over the dataset.
iterator = data.make_initializable_iterator()
# Initialize the iterator.
sess.run(iterator.initializer)
# Get next data batch.
x = iterator.get_next()
# Dequeue data and display.
for i in range(3):
print(sess.run(x))
print("")
输出结果
{'no': array([141, 41, 88, 24]), 'petal_width': array([2.4, 0.3, 1.3, 0.5], dtype=float32), 'sepal_width': array([3.1, 3.5, 2.3, 3.3], dtype=float32), 'sepal_length': array([6.7, 5. , 6.3, 5.1], dtype=float32), 'petal_length': array([5.6, 1.3, 4.4, 1.7], dtype=float32), 'species': array([b'virginica', b'setosa', b'versicolor', b'setosa'], dtype=object)}
{'no': array([84, 56, 64, 35]), 'petal_width': array([1.6, 1.3, 1.4, 0.2], dtype=float32), 'sepal_width': array([2.7, 2.8, 2.9, 3.1], dtype=float32), 'sepal_length': array([6. , 5.7, 6.1, 4.9], dtype=float32), 'petal_length': array([5.1, 4.5, 4.7, 1.5], dtype=float32), 'species': array([b'versicolor', b'versicolor', b'versicolor', b'setosa'],
dtype=object)}
{'no': array([ 21, 144, 147, 119]), 'petal_width': array([0.2, 2.3, 1.9, 2.3], dtype=float32), 'sepal_width': array([3.4, 3.2, 2.5, 2.6], dtype=float32), 'sepal_length': array([5.4, 6.8, 6.3, 7.7], dtype=float32), 'petal_length': array([1.7, 5.9, 5. , 6.9], dtype=float32), 'species': array([b'setosa', b'virginica', b'virginica', b'virginica'], dtype=object)}
预处理
你可以对输入的样本进行任意的预处理, 这些预处理不依赖于训练参数, 你可以在tensorflow/models/image/cifar10/cifar10.py
找到数据归一化, 提取随机数据片,增加噪声或失真等等预处理的例子。
批处理
在数据输入管线的末端, 我们需要有另一个队列来执行输入样本的训练,评价和推理。因此我们使用tf.train.shuffle_batch函数
来对队列中的样本进行乱序处理. 该函数是先将队列中数据打乱,然后再从队列里读取出来,因此队列中剩下的数据也是乱序的.
import tensorflow as tf
import numpy as np
import os
def read_my_file_format(filename_queue):
reader = tf.TextLineReader()
key, value =reader.read(filename_queue)
record_defaults = [[1.0],[1.0],[1.0],[1.0],["1"]]
col2, col3, col4, col5, col6 = tf.decode_csv(
value, record_defaults = record_defaults)
return [col2, col3, col4, col5], col6
def input_pipeline(filenames, batch_size, num_epochs = None):
filename_queue = tf.train.string_input_producer(
file_path, num_epochs=num_epochs, shuffle = True)
features, label = read_my_file_format(filename_queue)
min_after_dequeue = 5
capacity = min_after_dequeue+3*batch_size
features_batch, label_batch = tf.train.shuffle_batch(
[features, label], batch_size = batch_size, capacity = capacity,
min_after_dequeue = min_after_dequeue)
return features_batch, label_batch
file_path = ["/home/Documents/data/iris.data"]
features_batch, label_batch = input_pipeline(file_path, 10)
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord = coord)
for _ in range(5):
feat, lb = sess.run([features_batch, label_batch])
print(feat,lb)
coord.request_stop()
coord.join(threads)
输出结果(部分)
[[4.8 3. 1.4 0.1]
[4.8 3.4 1.9 0.2]
[4.8 3.4 1.6 0.2]
[5.1 3.3 1.7 0.5]
[5. 3. 1.6 0.2]
[5.2 3.4 1.4 0.2]
[5.2 3.5 1.5 0.2]
[5. 3.4 1.6 0.4]
[5.4 3.4 1.5 0.4]
[4.9 3.1 1.5 0.1]] [b'Iris-setosa' b'Iris-setosa' b'Iris-setosa' b'Iris-setosa'
b'Iris-setosa' b'Iris-setosa' b'Iris-setosa' b'Iris-setosa'
b'Iris-setosa' b'Iris-setosa']
[[5.1 3.5 1.4 0.3]
[4.8 3.1 1.6 0.2]
[5.5 4.2 1.4 0.2]
[4.7 3.2 1.6 0.2]
[4.9 3.1 1.5 0.1]
[5. 3.2 1.2 0.2]
[4.4 3. 1.3 0.2]
[5.5 3.5 1.3 0.2]
[4.4 3.2 1.3 0.2]
[4.5 2.3 1.3 0.3]] [b'Iris-setosa' b'Iris-setosa' b'Iris-setosa' b'Iris-setosa'
b'Iris-setosa' b'Iris-setosa' b'Iris-setosa' b'Iris-setosa'
b'Iris-setosa' b'Iris-setosa']
如果你需要对不同文件中的样子有更强的乱序和并行处理,可以使用tf.train.shuffle_batch_join
函数. 示例:
def read_my_file_format(filename_queue):
# Same as above
def input_pipeline(filenames, batch_size, read_threads, num_epochs=None):
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example_list = [read_my_file_format(filename_queue)
for _ in range(read_threads)]
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch_join(
example_list, batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch, label_batch
在这个例子中, 你虽然只使用了一个文件名队列, 但是TensorFlow依然能保证多个文件阅读器从同一次迭代(epoch)的不同文件中读取数据,直到这次迭代的所有文件都被开始读取为止。(通常来说一个线程来对文件名队列进行填充的效率是足够的)
另一种替代方案是: 使用tf.train.shuffle_batch
函数,设置num_threads的值大于1, 使用多个线程在tensor_list中读取文件.这种方案可以保证同一时刻只在一个文件中进行读取操作(但是读取速度依然优于单线程),而不是之前的同时读取多个文件。这种方案的优点是:
- 避免了两个不同的线程从同一个文件中读取同一个样本。
- 避免了过多的磁盘搜索操作。
你一共需要多少个读取线程呢? 函数tf.train.shuffle_batch*
为TensorFlow图提供了获取文件名队列中的元素个数之和的方法. 如果你有足够多的读取线程, 文件名队列中的元素个数之和应该一直是一个略高于0的数。具体可以参考TensorBoard:可视化学习.
创建线程并使用QueueRunner
对象来预取
在我们的代码中tf.train.string_input_producer()
生成了文件名队列, 在TensorFlow中,队列不仅仅是一种数据结构,还是异步计算张量取值的一个重要机制。比如多个线程可以同时向一个队列中写元素,或者同时读取一个队列中的元素。TF提供了tf.Coordinator和tf.QueueRunner两个类来完成多线程协同的功能.从设计上这两个类必须被一起使用. Coordinator类是线程协调器, 用来帮助多个线程协同工作,多个线程同步终止。 其主要方法有:
- should_stop():如果线程应该停止则返回True。
- request_stop(
):请求该线程停止。 - join(
- ):等待被指定的线程终止。
QueueRunner是队列管理器,主要用于启动多个线程来操作同一个队列,启动的这些线程可以通过上面介绍的tf.Coordinator类来统一管理. QueueRunner会协调多个工作线程同时将多个张量推入同一个队列中.
在Python的训练程序中,创建一个QueueRunner来运行几个线程, 这几个线程处理样本,并且将样本推入队列. 创建一个Coordinator,让queue runner使用Coordinator来启动这些线程,创建一个训练的循环, 并且使用Coordinator来控制QueueRunner的线程们的终止, 如果你对训练迭代数做了限制,那么需要使用一个训练迭代数计数器,并且需要被初始化。推荐的代码模板如下:
# Create the graph, etc.
init_op = tf.initialize_all_variables()
# Create a session for running operations in the Graph.
sess = tf.Session()
# Initialize the variables (like the epoch counter).
sess.run(init_op)
# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
while not coord.should_stop():
# Run training steps or whatever
sess.run(train_op)
except tf.errors.OutOfRangeError:
print 'Done training -- epoch limit reached'
finally:
# When done, ask the threads to stop.
coord.request_stop()
# Wait for threads to finish.
coord.join(threads)
sess.close()
疑问:这是怎么回事
首先,我们先创建数据流图,这个数据流图由一些流水线的阶段组成,阶段间用队列连接在一起。第一阶段将生成文件名,我们读取这些文件名并且把他们排到文件名队列中。第二阶段从文件中读取数据(使用Reader),产生样本,而且把样本放在一个样本队列中。根据你的设置,实际上也可以拷贝第二阶段的样本,使得他们相互独立,这样就可以从多个文件中并行读取。在第二阶段的最后是一个排队操作,就是入队到队列中去,在下一阶段出队。因为我们是要开始运行这些入队操作的线程,所以我们的训练循环会使得样本队列中的样本不断地出队
在tf.train中要创建这些队列和执行入队操作,就要添加tf.train.QueueRunner到一个使用tf.train.add_queue_runner函数的数据流图中。每个QueueRunner负责一个阶段,处理那些需要在线程中运行的入队操作的列表。一旦数据流图构造成功,tf.train.start_queue_runners
函数就会要求数据流图中每个QueueRunner去开始它的线程运行入队操作.
如果一切顺利的话,你现在可以执行你的训练步骤,同时队列也会被后台线程来填充。如果你设置了最大训练迭代数,在某些时候,样本出队的操作可能会得到一个tf.OutOfRangeError的错误。这其实是TensorFlow的“文件结束”(EOF) ———— 这就意味着已经达到了最大训练迭代数,已经没有更多可用的样本了.
最后一个因素是Coordinator。tf.train.Coordinator()
创建进线程协调器. 这是负责在收到任何关闭信号的时候,让所有的线程都知道。最常用的是在发生异常时这种情况就会呈现出来,比如说其中一个线程在运行某些操作时出现错误(或一个普通的Python异常).
想要了解更多的关于threading, queues, QueueRunners, and Coordinators的内容可以看这里.
疑问: 在达到最大训练迭代数的时候如何清理关闭线程?
想象一下,你有一个模型并且设置了最大训练迭代数。这意味着,生成文件的那个线程将只会在产生OutOfRange错误之前运行许多次。该QueueRunner会捕获该错误,并且关闭文件名的队列,最后退出线程。关闭队列做了两件事情:
- 如果还试着对文件名队列执行入队操作时将发生错误。任何线程不应该尝试去这样做,但是当队列因为其他错误而关闭时,这就会有用了。
- 任何当前或将来出队操作要么成功(如果队列中还有足够的元素)或立即失败(发生OutOfRange错误)。它们不会防止等待更多的元素被添加到队列中,因为上面的一点已经保证了这种情况不会发生。
关键是,当在文件名队列被关闭时候,有可能还有许多文件名在该队列中,这样下一阶段的流水线(包括reader和其它预处理)还可以继续运行一段时间。 一旦文件名队列空了之后,如果后面的流水线还要尝试从文件名队列中取出一个文件名(例如,从一个已经处理完文件的reader中),这将会触发OutOfRange错误。在这种情况下,即使你可能有一个QueueRunner关联着多个线程。如果这不是在QueueRunner中的最后那个线程,OutOfRange错误仅仅只会使得一个线程退出。这使得其他那些正处理自己的最后一个文件的线程继续运行,直至他们完成为止。 (但如果假设你使用的是tf.train.Coordinator,其他类型的错误将导致所有线程停止)。一旦所有的reader线程触发OutOfRange错误,然后才是下一个队列,再是样本队列被关闭。
同样,样本队列中会有一些已经入队的元素,所以样本训练将一直持续直到样本队列中再没有样本为止。如果样本队列是一个RandomShuffleQueue,因为你使用了shuffle_batch 或者 shuffle_batch_join,所以通常不会出现以往那种队列中的元素会比min_after_dequeue 定义的更少的情况。 然而,一旦该队列被关闭,min_after_dequeue设置的限定值将失效,最终队列将为空。在这一点来说,当实际训练线程尝试从样本队列中取出数据时,将会触发OutOfRange错误,然后训练线程会退出。一旦所有的培训线程完成,tf.train.Coordinator.join会返回,你就可以正常退出了。
筛选记录或产生每个记录的多个样本
举个例子,有形式为[x, y, z]的样本,我们可以生成一批形式为[batch, x, y, z]的样本。 如果你想滤除这个记录(或许不需要这样的设置),那么可以设置batch的大小为0;但如果你需要每个记录产生多个样本,那么batch的值可以大于1。 然后很简单,只需调用批处理函数(比如: shuffle_batch or shuffle_batch_join)去设置enqueue_many=True就可以实现。enqueue_many主要是设置tensor中的数据是否能重复,如果想要实现同一个样本多次出现可以将其设置为:“True”,如果只想要其出现一次,也就是保持数据的唯一性,这时候我们将其设置为默认值:“False”
稀疏输入数据
SparseTensors这种数据类型使用队列来处理不是太好。如果要使用SparseTensors你就必须在批处理之后使用tf.parse_example 去解析字符串记录 (而不是在批处理之前使用 tf.parse_single_example) 。
预取数据
这仅用于可以完全加载到存储器中的小的数据集。有两种方法:
- 存储在常数中。
- 存储在变量中,初始化后,永远不要改变它的值。
使用常数更简单一些,但是会使用更多的内存(因为常数会内联的存储在数据流图数据结构中,这个结构体可能会被复制几次)。
training_data = ...
training_labels = ...
with tf.Session():
input_data = tf.constant(training_data)
input_labels = tf.constant(training_labels)
...
要改为使用变量的方式,你就需要在数据流图建立后初始化这个变量。
training_data = ...
training_labels = ...
with tf.Session() as sess:
data_initializer = tf.placeholder(dtype=training_data.dtype,
shape=training_data.shape)
label_initializer = tf.placeholder(dtype=training_labels.dtype,
shape=training_labels.shape)
input_data = tf.Variable(data_initalizer, trainable=False, collections=[])
input_labels = tf.Variable(label_initalizer, trainable=False, collections=[])
...
sess.run(input_data.initializer,
feed_dict={data_initializer: training_data})
sess.run(input_labels.initializer,
feed_dict={label_initializer: training_lables})
设定trainable=False
可以防止该变量被数据流图的 GraphKeys.TRAINABLE_VARIABLES
收集, 这样我们就不会在训练的时候尝试更新它的值; 设定 collections=[]
可以防止GraphKeys.VARIABLES
收集后做为保存和恢复的中断点。
无论哪种方式,[tf.train.slice_input_producer function](http://www.tensorfly.cn/tfdoc/api_docs/python/io_ops.html#slice_input_producer)
函数可以被用来每次产生一个切片。这样就会让样本在整个迭代中被打乱,所以在使用批处理的时候不需要再次打乱样本。所以我们不使用shuffle_batch
函数,取而代之的是纯tf.train.batch(http://www.tensorfly.cn/tfdoc/api_docs/python/io_ops.html#batch)
函数。 如果要使用多个线程进行预处理,需要将num_threads参数设置为大于1的数字。
在tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py 中可以找到一个MNIST例子,使用常数来预加载。 另外使用变量来预加载的例子在tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py,你可以用上面 fully_connected_feed
和 fully_connected_reader
的描述来进行比较。
多输入管道
通常你会在一个数据集上面训练,然后在另外一个数据集上做评估计算(或称为 "eval")。 这样做的一种方法是,实际上包含两个独立的进程:
训练过程中读取输入数据,并定期将所有的训练的变量写入还原点文件)。
在计算过程中恢复还原点文件到一个推理模型中,读取有效的输入数据。
这两个进程在下面的例子中已经完成了:the example CIFAR-10 model,有以下几个好处:
eval被当做训练后变量的一个简单映射。
你甚至可以在训练完成和退出后执行eval。
你可以在同一个进程的相同的数据流图中有训练和eval,并分享他们的训练后的变量。参考the shared variables tutorial.