第十二节,TensorFlow读取数据的几种方法以及队列的使用
目录
TensorFlow程序读取数据一共有3种方法:
- 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
- 从文件读取数据: 在TensorFlow图的起始, 让一个输入管道从文件中读取数据。
- 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。
一 预加载数据
import tensorflow as tf x1 = tf.constant([2,3,4]) x2 = tf.constant([4,0,1]) y = tf.add(x1,x2) with tf.Session() as sess: print(sess.run(y))
在这里使用x1,x2保存具体的值,即将数据直接内嵌到图中,再将图传入会话中执行,当数据量较大时,图的输出会遇到效率问题。
二 供给数据
import tensorflow as tf x1 = tf.placeholder(tf.int32) x2 = tf.placeholder(tf.int32) #用python产生数据 v1 = [2,3,4] v2 = [4,0,1] y = tf.add(x1,x2) with tf.Session() as sess: print(sess.run(y,feed_dict={x1:v1,x2:v2}))
在这里x1,x2只是占位符,没有具体的值,那么运行的时候去哪取值呢?这时候就要用到sess.run()的feed_dict参数,将python产生的数据传入,并计算y。
以上两种方法都很方便,但是遇到大型数据的时候就会很吃力,即使是Feed_dict,中间环节的增加也是不小的开销,因为数据量大的时候,TensorFlow程序运行的每一步,我们都需要使用python代码去从文件中读取数据,并对读取到的文件数据进行解码。最优的方案就是在图中定义好文件读取的方法,让TF自己从文件中读取数据,并解码成可用的样本集。
三 TensorFlow中的队列机制
从文件中读取数据的方法有很多,比如可以在一个文本里面写入图片数据的路径和标签,然后用tensorflow的read_file()读入图片;也可以将图片和标签的值直接存放在CSV或者txt文件。
我们会在后面陆续介绍以下几种读取文件的方式:
- 从字典结构的数据文件读取
- 从bin文件读取
- 从CSV(TXT)读取
- 从原图读取
- TFRecord格式文件的读取
在讲解文件的读取之前,我们需要先了解一下TensorFlow中的队列机制,后面也会详细介绍。
TensorFlow提供了一个队列机制,通过多线程将读取数据与计算数据分开。因为在处理海量数据集的训练时,无法把数据集一次全部载入到内存中,需要一边从硬盘中读取,一边进行训练,为了加快训练速度,我们可以采用多个线程读取数据,一个线程消耗数据。
下面简要介绍一下,TensorFlow里与Queue有关的概念和用法。详细内容点击原文。
其实概念只有三个:
Queue
是TF队列和缓存机制的实现QueueRunner
是TF中对操作Queue的线程的封装Coordinator
是TF中用来协调线程运行的工具
虽然它们经常同时出现,但这三样东西在TensorFlow里面是可以单独使用的,不妨先分开来看待。
1.Queue
据实现的方式不同,分成具体的几种类型,例如:
- tf.FIFOQueue :按入列顺序出列的队列
- tf.RandomShuffleQueue :随机顺序出列的队列
- tf.PaddingFIFOQueue :以固定长度批量出列的队列
- tf.PriorityQueue :带优先级出列的队列
- ... ...
这些类型的Queue除了自身的性质不太一样外,创建、使用的方法基本是相同的。
创建函数的参数:
tf.FIFOQueue(capacity, dtypes, shapes=None, names=None,
shared_name=None, name="fifo_queue")
#创建的图:一个先入先出队列,以及初始化,出队,+1,入队操作 q = tf.FIFOQueue(3, "float") init = q.enqueue_many(([0.1, 0.2, 0.3],)) x = q.dequeue() y = x + 1 q_inc = q.enqueue([y]) #开启一个session,session是会话,会话的潜在含义是状态保持,各种tensor的状态保持 with tf.Session() as sess: sess.run(init) for i in range(2): sess.run(q_inc) quelen = sess.run(q.size()) for i in range(quelen): print (sess.run(q.dequeue()))
2. QueueRunner
之前的例子中,入队操作都在主线程中进行,Session中可以多个线程一起运行。 在数据输入的应用场景中,入队操作从硬盘上读取,入队操作是从硬盘中读取输入,放到内存当中,速度较慢。 使用QueueRunner
可以创建一系列新的线程进行入队操作,让主线程继续使用数据。如果在训练神经网络的场景中,就是训练网络和读取数据是异步的,主线程在训练网络,另一个线程在将数据从硬盘读入内存。
''' QueueRunner()的使用 ''' q = tf.FIFOQueue(10, "float") counter = tf.Variable(0.0) #计数器 # 给计数器加一 increment_op = tf.assign_add(counter, 1.0) # 将计数器加入队列 enqueue_op = q.enqueue(counter) # 创建QueueRunner # 用多个线程向队列添加数据 # 这里实际创建了4个线程,两个增加计数,两个执行入队 qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2) #主线程 with tf.Session() as sess: sess.run(tf.initialize_all_variables()) #启动入队线程 enqueue_threads = qr.create_threads(sess, start=True) #主线程 for i in range(10): print (sess.run(q.dequeue()))
能正确输出结果,但是最后会报错,ERROR:tensorflow:Exception in QueueRunner: Session has been closed.也就是说,当循环结束后,该Session就会自动关闭,相当于main函数已经结束了。
''' QueueRunner()的使用 ''' q = tf.FIFOQueue(10, "float") counter = tf.Variable(0.0) #计数器 # 给计数器加一 increment_op = tf.assign_add(counter, 1.0) # 将计数器加入队列 enqueue_op = q.enqueue(counter) # 创建QueueRunner # 用多个线程向队列添加数据 # 这里实际创建了4个线程,两个增加计数,两个执行入队 qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2) ''' #主线程 with tf.Session() as sess: sess.run(tf.initialize_all_variables()) #启动入队线程 enqueue_threads = qr.create_threads(sess, start=True) #主线程 for i in range(10): print (sess.run(q.dequeue())) ''' # 主线程 sess = tf.Session() sess.run(tf.initialize_all_variables()) # 启动入队线程 enqueue_threads = qr.create_threads(sess, start=True) # 主线程 for i in range(0, 10): print(sess.run(q.dequeue()))
不使用with tf.Session,那么Session就不会自动关闭。
并不是我们设想的1,2,3,4,本质原因是增加计数的进程会不停的后台运行,执行入队的进程会先执行10次(因为队列长度只有10),然后主线程开始消费数据,当一部分数据消费被后,入队的进程又会开始执行。最终主线程消费完10个数据后停止,但其他线程继续运行,程序不会结束。
经验:因为tensorflow是在图上进行计算,要驱动一张图进行计算,必须要送入数据,如果说数据没有送进去,那么sess.run(),就无法执行,tf也不会主动报错,提示没有数据送进去,其实tf也不能主动报错,因为tf的训练过程和读取数据的过程其实是异步的。tf会一直挂起,等待数据准备好。现象就是tf的程序不报错,但是一直不动,跟挂起类似。
''' QueueRunner()的使用 ''' q = tf.FIFOQueue(10, "float") counter = tf.Variable(0.0) #计数器 # 给计数器加一 increment_op = tf.assign_add(counter, 1.0) # 将计数器加入队列 enqueue_op = q.enqueue(counter) # 创建QueueRunner # 用多个线程向队列添加数据 # 这里实际创建了4个线程,两个增加计数,两个执行入队 qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2) #主线程 with tf.Session() as sess: sess.run(tf.initialize_all_variables()) #启动入队线程 enqueue_threads = qr.create_threads(sess, start=True) #主线程 for i in range(10): print (sess.run(q.dequeue()))
上图将生成数据的线程注释掉,程序就会卡在sess.run(q.dequeue()),等待数据的到来QueueRunner是用来启动入队线程用的。
3.Coordinator
Coordinator是个用来保存线程组运行状态的协调器对象,它和TensorFlow的Queue没有必然关系,是可以单独和Python线程使用的。例如:
''' Coordinator ''' import threading, time # 子线程函数 def loop(coord, id): t = 0 while not coord.should_stop(): print(id) time.sleep(1) t += 1 # 只有1号线程调用request_stop方法 if (t >= 2 and id == 0): coord.request_stop() # 主线程 coord = tf.train.Coordinator() # 使用Python API创建10个线程 threads = [threading.Thread(target=loop, args=(coord, i)) for i in range(10)] # 启动所有线程,并等待线程结束 for t in threads: t.start() coord.join(threads)
将这个程序运行起来,会发现所有的子线程执行完两个周期后都会停止,主线程会等待所有子线程都停止后结束,从而使整个程序结束。由此可见,只要有任何一个线程调用了Coordinator的request_stop
方法,所有的线程都可以通过should_stop
方法感知并停止当前线程。
将QueueRunner和Coordinator一起使用,实际上就是封装了这个判断操作,从而使任何一个出现异常时,能够正常结束整个程序,同时主线程也可以直接调用request_stop
方法来停止所有子线程的执行。
简要 介绍完了TensorFlow中队列机制后,我们再来看一下如何从文件中读取数据。
四 从文件中读取数据
1.从字典结构的数据文件读取(python数据格式)
(1)在介绍字典结构的数据文件的读取之前,我们先来介绍怎么创建字典结构的数据文件。
- 先要准备好图片文件,我们使用Open CV3进行图像读取。
- 把cv2.imread()读取到的图像进行裁切,扭曲,等处理。
- 使用numpy才对数据进行处理,比如维度合并。
- 把处理好的每一张图像的数据和标签分别存放在对应的list(或者ndarray)中。
- 创建一个字典,包含两个元素‘data’和'labels',并分别赋值为上面的list。
- 使用pickle模块对字典进行序列化,并保存到文件中。
具体代码我们查看如下文章:图片存储为cifar的Python数据格式
如果针对图片比较多的情况,我们不太可能把所有图像都写入个文件,我们可以分批把图像写入几个文件中。
(2)cifar10数据有三种版本,分别是MATLAB,Python和bin版本 数据下载链接: http://www.cs.toronto.edu/~kriz/cifar.html
其中Python版本的数据即是以字典结构存储的数据 。
针对字典结构的数据文件读取,我在AlexNet那节中有详细介绍,主要就是通过pickle模块对文件进行反序列化,获取我们所需要的数据。
2.从bin文件读取
在官网的cifar的例子中就是从bin文件中读取的。bin文件需要以一定的size格式存储,比如每个样本的值占多少字节,label占多少字节,且这对于每个样本都是固定的,然后一个挨着一个存储。这样就可以使用tf.FixedLengthRecordReader 类来每次读取固定长度的字节,正好对应一个样本存储的字节(包括label)。并且用tf.decode_raw进行解析。
(1)制作bin file
如何将自己的图片存为bin file,可以看看下面这篇博客,这篇博客使用C++和opencv将图片存为二进制文件: http://blog.csdn.net/code_better/article/details/53289759
(2)从bin file读入
在后面会详细讲解如何从二进制记录文件中读取数据,并以cifar10_input.py作为案例。
def read_cifar10(filename_queue): """Reads and parses examples from CIFAR10 data files. Recommendation: if you want N-way read parallelism, call this function N times. This will give you N independent Readers reading different files & positions within those files, which will give better mixing of examples. Args: filename_queue: A queue of strings with the filenames to read from. Returns: An object representing a single example, with the following fields: height: number of rows in the result (32) width: number of columns in the result (32) depth: number of color channels in the result (3) key: a scalar string Tensor describing the filename & record number for this example. label: an int32 Tensor with the label in the range 0..9. uint8image: a [height, width, depth] uint8 Tensor with the image data """ class CIFAR10Record(object): pass result = CIFAR10Record() # Dimensions of the images in the CIFAR-10 dataset. # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the # input format. label_bytes = 1 # 2 for CIFAR-100 result.height = 32 result.width = 32 result.depth = 3 image_bytes = result.height * result.width * result.depth # Every record consists of a label followed by the image, with a # fixed number of bytes for each. record_bytes = label_bytes + image_bytes # Read a record, getting filenames from the filename_queue. No # header or footer in the CIFAR-10 format, so we leave header_bytes # and footer_bytes at their default of 0. reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) result.key, value = reader.read(filename_queue) # Convert from a string to a vector of uint8 that is record_bytes long. record_bytes = tf.decode_raw(value, tf.uint8) # The first bytes represent the label, which we convert from uint8->int32. result.label = tf.cast( tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32) # The remaining bytes after the label represent the image, which we reshape # from [depth * height * width] to [depth, height, width]. depth_major = tf.reshape( tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]), [result.depth, result.height, result.width]) # Convert from [depth, height, width] to [height, width, depth]. result.uint8image = tf.transpose(depth_major, [1, 2, 0]) return result
这段代码如果看不懂,可以先跳过。
亲爱的读者和支持者们,自动博客加入了打赏功能,陆陆续续收到了各位老铁的打赏。在此,我想由衷地感谢每一位对我们博客的支持和打赏。你们的慷慨与支持,是我们前行的动力与源泉。
日期 | 姓名 | 金额 |
---|---|---|
2023-09-06 | *源 | 19 |
2023-09-11 | *朝科 | 88 |
2023-09-21 | *号 | 5 |
2023-09-16 | *真 | 60 |
2023-10-26 | *通 | 9.9 |
2023-11-04 | *慎 | 0.66 |
2023-11-24 | *恩 | 0.01 |
2023-12-30 | I*B | 1 |
2024-01-28 | *兴 | 20 |
2024-02-01 | QYing | 20 |
2024-02-11 | *督 | 6 |
2024-02-18 | 一*x | 1 |
2024-02-20 | c*l | 18.88 |
2024-01-01 | *I | 5 |
2024-04-08 | *程 | 150 |
2024-04-18 | *超 | 20 |
2024-04-26 | .*V | 30 |
2024-05-08 | D*W | 5 |
2024-05-29 | *辉 | 20 |
2024-05-30 | *雄 | 10 |
2024-06-08 | *: | 10 |
2024-06-23 | 小狮子 | 666 |
2024-06-28 | *s | 6.66 |
2024-06-29 | *炼 | 1 |
2024-06-30 | *! | 1 |
2024-07-08 | *方 | 20 |
2024-07-18 | A*1 | 6.66 |
2024-07-31 | *北 | 12 |
2024-08-13 | *基 | 1 |
2024-08-23 | n*s | 2 |
2024-09-02 | *源 | 50 |
2024-09-04 | *J | 2 |
2024-09-06 | *强 | 8.8 |
2024-09-09 | *波 | 1 |
2024-09-10 | *口 | 1 |
2024-09-10 | *波 | 1 |
2024-09-12 | *波 | 10 |
2024-09-18 | *明 | 1.68 |
2024-09-26 | B*h | 10 |
2024-09-30 | 岁 | 10 |
2024-10-02 | M*i | 1 |
2024-10-14 | *朋 | 10 |
2024-10-22 | *海 | 10 |
2024-10-23 | *南 | 10 |
2024-10-26 | *节 | 6.66 |
2024-10-27 | *o | 5 |
2024-10-28 | W*F | 6.66 |
2024-10-29 | R*n | 6.66 |
2024-11-02 | *球 | 6 |
2024-11-021 | *鑫 | 6.66 |
2024-11-25 | *沙 | 5 |
2024-11-29 | C*n | 2.88 |

【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了