程序项目代做,有需求私信(小程序、网站、爬虫、电路板设计、驱动、应用程序开发、毕设疑难问题处理等)

第十二节,TensorFlow读取数据的几种方法以及队列的使用

TensorFlow程序读取数据一共有3种方法:

  • 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
  • 从文件读取数据: 在TensorFlow图的起始, 让一个输入管道从文件中读取数据。
  • 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。

一 预加载数据

import tensorflow as tf
x1 = tf.constant([2,3,4])
x2 = tf.constant([4,0,1])

y = tf.add(x1,x2)

with tf.Session() as sess:
    print(sess.run(y))

在这里使用x1,x2保存具体的值,即将数据直接内嵌到图中,再将图传入会话中执行,当数据量较大时,图的输出会遇到效率问题。

二 供给数据

复制代码
import tensorflow as tf
x1 = tf.placeholder(tf.int32)
x2 = tf.placeholder(tf.int32)
#用python产生数据
v1 = [2,3,4]
v2 = [4,0,1]

y = tf.add(x1,x2)

with tf.Session() as sess:
    print(sess.run(y,feed_dict={x1:v1,x2:v2}))
复制代码

在这里x1,x2只是占位符,没有具体的值,那么运行的时候去哪取值呢?这时候就要用到sess.run()的feed_dict参数,将python产生的数据传入,并计算y。

以上两种方法都很方便,但是遇到大型数据的时候就会很吃力,即使是Feed_dict,中间环节的增加也是不小的开销,因为数据量大的时候,TensorFlow程序运行的每一步,我们都需要使用python代码去从文件中读取数据,并对读取到的文件数据进行解码。最优的方案就是在图中定义好文件读取的方法,让TF自己从文件中读取数据,并解码成可用的样本集。

三 TensorFlow中的队列机制

从文件中读取数据的方法有很多,比如可以在一个文本里面写入图片数据的路径和标签,然后用tensorflow的read_file()读入图片;也可以将图片和标签的值直接存放在CSV或者txt文件。

我们会在后面陆续介绍以下几种读取文件的方式:

  • 从字典结构的数据文件读取
  • 从bin文件读取
  • 从CSV(TXT)读取
  • 从原图读取
  • TFRecord格式文件的读取

在讲解文件的读取之前,我们需要先了解一下TensorFlow中的队列机制,后面也会详细介绍。

TensorFlow提供了一个队列机制,通过多线程将读取数据与计算数据分开。因为在处理海量数据集的训练时,无法把数据集一次全部载入到内存中,需要一边从硬盘中读取,一边进行训练,为了加快训练速度,我们可以采用多个线程读取数据,一个线程消耗数据。

下面简要介绍一下,TensorFlow里与Queue有关的概念和用法。详细内容点击原文

其实概念只有三个:

  • Queue是TF队列和缓存机制的实现
  • QueueRunner是TF中对操作Queue的线程的封装
  • Coordinator是TF中用来协调线程运行的工具

虽然它们经常同时出现,但这三样东西在TensorFlow里面是可以单独使用的,不妨先分开来看待。

1.Queue

据实现的方式不同,分成具体的几种类型,例如:

  • tf.FIFOQueue :按入列顺序出列的队列
  • tf.RandomShuffleQueue :随机顺序出列的队列
  • tf.PaddingFIFOQueue :以固定长度批量出列的队列
  • tf.PriorityQueue :带优先级出列的队列
  • ... ...

这些类型的Queue除了自身的性质不太一样外,创建、使用的方法基本是相同的。

创建函数的参数:

tf.FIFOQueue(capacity, dtypes, shapes=None, names=None,
               shared_name=None, name="fifo_queue")
Queue主要包含入列(enqueue)出列(dequeue)两个操作。队列本身也是图中的一个节点。其他节点(enqueue, dequeue)可以修改队列节点中的内容。enqueue操作返回计算图中的一个Operation节点,dequeue操作返回一个Tensor值。Tensor在创建时同样只是一个定义(或称为“声明”),需要放在Session中运行才能获得真正的数值。下面是一个单独使用Queue的例子:
复制代码
#创建的图:一个先入先出队列,以及初始化,出队,+1,入队操作  
q = tf.FIFOQueue(3, "float")  
init = q.enqueue_many(([0.1, 0.2, 0.3],))  
x = q.dequeue()  
y = x + 1  
q_inc = q.enqueue([y])  
  
#开启一个session,session是会话,会话的潜在含义是状态保持,各种tensor的状态保持  
with tf.Session() as sess:  
    sess.run(init)  
  
    for i in range(2):  
            sess.run(q_inc)    
    quelen =  sess.run(q.size())  
    
    for i in range(quelen):  
            print (sess.run(q.dequeue())) 
复制代码

2. QueueRunner

之前的例子中,入队操作都在主线程中进行,Session中可以多个线程一起运行。 在数据输入的应用场景中,入队操作从硬盘上读取,入队操作是从硬盘中读取输入,放到内存当中,速度较慢。 使用QueueRunner可以创建一系列新的线程进行入队操作,让主线程继续使用数据。如果在训练神经网络的场景中,就是训练网络和读取数据是异步的,主线程在训练网络,另一个线程在将数据从硬盘读入内存。

复制代码
'''
QueueRunner()的使用
'''
q = tf.FIFOQueue(10, "float")  
counter = tf.Variable(0.0)  #计数器
# 给计数器加一
increment_op = tf.assign_add(counter, 1.0)
# 将计数器加入队列
enqueue_op = q.enqueue(counter)

# 创建QueueRunner
# 用多个线程向队列添加数据
# 这里实际创建了4个线程,两个增加计数,两个执行入队
qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2)


#主线程  
with tf.Session() as sess:  
    sess.run(tf.initialize_all_variables())  
    #启动入队线程  
    enqueue_threads = qr.create_threads(sess, start=True)  
    #主线程  
    for i in range(10):              
        print (sess.run(q.dequeue()))  
复制代码

能正确输出结果,但是最后会报错,ERROR:tensorflow:Exception in QueueRunner: Session has been closed.也就是说,当循环结束后,该Session就会自动关闭,相当于main函数已经结束了。

复制代码
'''
QueueRunner()的使用
'''
q = tf.FIFOQueue(10, "float")  
counter = tf.Variable(0.0)  #计数器
# 给计数器加一
increment_op = tf.assign_add(counter, 1.0)
# 将计数器加入队列
enqueue_op = q.enqueue(counter)

# 创建QueueRunner
# 用多个线程向队列添加数据
# 这里实际创建了4个线程,两个增加计数,两个执行入队
qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2)

'''

#主线程  
with tf.Session() as sess:  
    sess.run(tf.initialize_all_variables())  
    #启动入队线程  
    enqueue_threads = qr.create_threads(sess, start=True)  
    #主线程  
    for i in range(10):              
        print (sess.run(q.dequeue()))  

'''


  
# 主线程  
sess = tf.Session()  
sess.run(tf.initialize_all_variables())  
  
# 启动入队线程  
enqueue_threads = qr.create_threads(sess, start=True) 
  
# 主线程  
for i in range(0, 10):  
    print(sess.run(q.dequeue()))  
复制代码

不使用with tf.Session,那么Session就不会自动关闭。

并不是我们设想的1,2,3,4,本质原因是增加计数的进程会不停的后台运行,执行入队的进程会先执行10次(因为队列长度只有10),然后主线程开始消费数据,当一部分数据消费被后,入队的进程又会开始执行。最终主线程消费完10个数据后停止,但其他线程继续运行,程序不会结束。

经验:因为tensorflow是在图上进行计算,要驱动一张图进行计算,必须要送入数据,如果说数据没有送进去,那么sess.run(),就无法执行,tf也不会主动报错,提示没有数据送进去,其实tf也不能主动报错,因为tf的训练过程和读取数据的过程其实是异步的。tf会一直挂起,等待数据准备好。现象就是tf的程序不报错,但是一直不动,跟挂起类似。 

复制代码
'''
QueueRunner()的使用
'''
q = tf.FIFOQueue(10, "float")  
counter = tf.Variable(0.0)  #计数器
# 给计数器加一
increment_op = tf.assign_add(counter, 1.0)
# 将计数器加入队列
enqueue_op = q.enqueue(counter)

# 创建QueueRunner
# 用多个线程向队列添加数据
# 这里实际创建了4个线程,两个增加计数,两个执行入队
qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2)


#主线程  
with tf.Session() as sess:  
    sess.run(tf.initialize_all_variables())  
    #启动入队线程  
    enqueue_threads = qr.create_threads(sess, start=True)  
    #主线程  
    for i in range(10):              
        print (sess.run(q.dequeue()))  
复制代码

上图将生成数据的线程注释掉,程序就会卡在sess.run(q.dequeue()),等待数据的到来QueueRunner是用来启动入队线程用的。

3.Coordinator

Coordinator是个用来保存线程组运行状态的协调器对象,它和TensorFlow的Queue没有必然关系,是可以单独和Python线程使用的。例如:

复制代码
'''
Coordinator
'''
import threading, time

# 子线程函数
def loop(coord, id):
    t = 0
    while not coord.should_stop():
        print(id)
        time.sleep(1)
        t += 1
        # 只有1号线程调用request_stop方法
        if (t >= 2 and id == 0):
            coord.request_stop()

# 主线程
coord = tf.train.Coordinator()
# 使用Python API创建10个线程
threads = [threading.Thread(target=loop, args=(coord, i)) for i in range(10)]

# 启动所有线程,并等待线程结束
for t in threads: t.start()
coord.join(threads)
复制代码

将这个程序运行起来,会发现所有的子线程执行完两个周期后都会停止,主线程会等待所有子线程都停止后结束,从而使整个程序结束。由此可见,只要有任何一个线程调用了Coordinator的request_stop方法,所有的线程都可以通过should_stop方法感知并停止当前线程。

将QueueRunner和Coordinator一起使用,实际上就是封装了这个判断操作,从而使任何一个出现异常时,能够正常结束整个程序,同时主线程也可以直接调用request_stop方法来停止所有子线程的执行。

简要 介绍完了TensorFlow中队列机制后,我们再来看一下如何从文件中读取数据。

四 从文件中读取数据

1.从字典结构的数据文件读取(python数据格式)

(1)在介绍字典结构的数据文件的读取之前,我们先来介绍怎么创建字典结构的数据文件。

  • 先要准备好图片文件,我们使用Open CV3进行图像读取。
  • 把cv2.imread()读取到的图像进行裁切,扭曲,等处理。
  • 使用numpy才对数据进行处理,比如维度合并。
  • 把处理好的每一张图像的数据和标签分别存放在对应的list(或者ndarray)中。
  • 创建一个字典,包含两个元素‘data’和'labels',并分别赋值为上面的list。
  • 使用pickle模块对字典进行序列化,并保存到文件中。

具体代码我们查看如下文章:图片存储为cifar的Python数据格式

如果针对图片比较多的情况,我们不太可能把所有图像都写入个文件,我们可以分批把图像写入几个文件中。

(2)cifar10数据有三种版本,分别是MATLAB,Python和bin版本 数据下载链接: http://www.cs.toronto.edu/~kriz/cifar.html 

其中Python版本的数据即是以字典结构存储的数据 。

针对字典结构的数据文件读取,我在AlexNet那节中有详细介绍,主要就是通过pickle模块对文件进行反序列化,获取我们所需要的数据。

2.从bin文件读取

在官网的cifar的例子中就是从bin文件中读取的。bin文件需要以一定的size格式存储,比如每个样本的值占多少字节,label占多少字节,且这对于每个样本都是固定的,然后一个挨着一个存储。这样就可以使用tf.FixedLengthRecordReader 类来每次读取固定长度的字节,正好对应一个样本存储的字节(包括label)。并且用tf.decode_raw进行解析。

(1)制作bin file 

如何将自己的图片存为bin file,可以看看下面这篇博客,这篇博客使用C++和opencv将图片存为二进制文件: http://blog.csdn.net/code_better/article/details/53289759 

(2)从bin file读入 
在后面会详细讲解如何从二进制记录文件中读取数据,并以cifar10_input.py作为案例。

复制代码
def read_cifar10(filename_queue):
  """Reads and parses examples from CIFAR10 data files.

  Recommendation: if you want N-way read parallelism, call this function
  N times.  This will give you N independent Readers reading different
  files & positions within those files, which will give better mixing of
  examples.

  Args:
    filename_queue: A queue of strings with the filenames to read from.

  Returns:
    An object representing a single example, with the following fields:
      height: number of rows in the result (32)
      width: number of columns in the result (32)
      depth: number of color channels in the result (3)
      key: a scalar string Tensor describing the filename & record number
        for this example.
      label: an int32 Tensor with the label in the range 0..9.
      uint8image: a [height, width, depth] uint8 Tensor with the image data
  """

  class CIFAR10Record(object):
    pass
  result = CIFAR10Record()

  # Dimensions of the images in the CIFAR-10 dataset.
  # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
  # input format.
  label_bytes = 1  # 2 for CIFAR-100
  result.height = 32
  result.width = 32
  result.depth = 3
  image_bytes = result.height * result.width * result.depth
  # Every record consists of a label followed by the image, with a
  # fixed number of bytes for each.
  record_bytes = label_bytes + image_bytes

  # Read a record, getting filenames from the filename_queue.  No
  # header or footer in the CIFAR-10 format, so we leave header_bytes
  # and footer_bytes at their default of 0.
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  result.key, value = reader.read(filename_queue)

  # Convert from a string to a vector of uint8 that is record_bytes long.
  record_bytes = tf.decode_raw(value, tf.uint8)

  # The first bytes represent the label, which we convert from uint8->int32.
  result.label = tf.cast(
      tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(
      tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
      [result.depth, result.height, result.width])
  # Convert from [depth, height, width] to [height, width, depth].
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])

  return result
复制代码

这段代码如果看不懂,可以先跳过。

亲爱的读者和支持者们,自动博客加入了打赏功能,陆陆续续收到了各位老铁的打赏。在此,我想由衷地感谢每一位对我们博客的支持和打赏。你们的慷慨与支持,是我们前行的动力与源泉。

日期姓名金额
2023-09-06*源19
2023-09-11*朝科88
2023-09-21*号5
2023-09-16*真60
2023-10-26*通9.9
2023-11-04*慎0.66
2023-11-24*恩0.01
2023-12-30I*B1
2024-01-28*兴20
2024-02-01QYing20
2024-02-11*督6
2024-02-18一*x1
2024-02-20c*l18.88
2024-01-01*I5
2024-04-08*程150
2024-04-18*超20
2024-04-26.*V30
2024-05-08D*W5
2024-05-29*辉20
2024-05-30*雄10
2024-06-08*:10
2024-06-23小狮子666
2024-06-28*s6.66
2024-06-29*炼1
2024-06-30*!1
2024-07-08*方20
2024-07-18A*16.66
2024-07-31*北12
2024-08-13*基1
2024-08-23n*s2
2024-09-02*源50
2024-09-04*J2
2024-09-06*强8.8
2024-09-09*波1
2024-09-10*口1
2024-09-10*波1
2024-09-12*波10
2024-09-18*明1.68
2024-09-26B*h10
2024-09-3010
2024-10-02M*i1
2024-10-14*朋10
2024-10-22*海10
2024-10-23*南10
2024-10-26*节6.66
2024-10-27*o5
2024-10-28W*F6.66
2024-10-29R*n6.66
2024-11-02*球6
2024-11-021*鑫6.66
2024-11-25*沙5
2024-11-29C*n2.88
posted @   大奥特曼打小怪兽  阅读(7869)  评论(0编辑  收藏  举报
编辑推荐:
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
如果有任何技术小问题,欢迎大家交流沟通,共同进步

公告 & 打赏

>>

欢迎打赏支持我 ^_^

最新公告

程序项目代做,有需求私信(小程序、网站、爬虫、电路板设计、驱动、应用程序开发、毕设疑难问题处理等)。

了解更多

点击右上角即可分享
微信分享提示