图像数据处理

图像的亮度、对比度等属性对图像的影响是非常大的,然而在很多图像识别问题中,这些因素都不应该影响最后的识别结果,所以在训练模型之前,需要对图像数据进行预处理,使训练得到的模型尽可能小地被无关因素影响。

7.1 TFRecord输入数据格式

7.1.1. TFRecord 格式介绍

7.1.2 TFRecord 样例程序

把mnist数据保存为tfrecord格式:

 1 #!coding:utf8
 2 
 3 import tensorflow as tf
 4 from tensorflow.examples.tutorials.mnist import input_data
 5 import numpy as np
 6 
 7 def _int64_feature(value):
 8     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 9 
10 def _bytes_feature(value):
11     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
12 
13 mnist = input_data.read_data_sets('D:\\files\\tf\mnist', one_hot=True)  # 确定label是由0、1组成的数组,还是单个整数。
14 images = mnist.train.images  # (55000, 784)
15 labels = mnist.train.labels
16 
17 pixels = images.shape[1]  # 784
18 num_examples = mnist.train.num_examples
19 
20 filename = 'D:\\files\\tf\yangxl.tfrecords'
21 writer = tf.python_io.TFRecordWriter(filename)
22 for index in range(num_examples):
23     image_raw = images[index].tostring()
24     example = tf.train.Example(features=tf.train.Features(feature={
25         'pixels': _int64_feature(pixels),
26         'label': _int64_feature(np.argmax(labels[index])),
27         'image_raw': _bytes_feature(image_raw)
28     }))
29     writer.write(example.SerializeToString())
30 writer.close()

 

读取tfrecord文件:

 1 import tensorflow as tf
 2 import matplotlib.pyplot as plt
 3 import numpy as np
 4 
 5 reader = tf.TFRecordReader()
 6 
 7 # 创建一个队列来维护输入文件列表
 8 file_queue = tf.train.string_input_producer(['/home/yangxl/files/mnist.tfrecords'])
 9 
10 # 从文件中读取一个样例; 一次性读取多个样例使用read_up_to函数
11 _, serialized_example = reader.read(file_queue)  # tensor
12 # 解析样例; 一次性解析多个样例使用parse_example函数
13 features = tf.parse_single_example(
14     serialized_example,
15     features={
16         # """
17         # tf提供了两种属性解析方法,一种是定长tf.FixedLenFeature,解析结果为一个Tensor;
18         # 另一种是变长tf.VarLenFeature,解析结果为SparseTensor,用于处理稀疏数据。
19         # 这里解析数据的格式需要和写入数据的格式一致。
20         # """
21         # 使用多行注释会报错:`KeyError: 'pixels'`, 我擦泪...
22 
23         # 解析时的键需要与保存时的键一致
24         'pixels': tf.FixedLenFeature([], tf.int64),
25         'label': tf.FixedLenFeature([], tf.int64),
26         'image_raw': tf.FixedLenFeature([], tf.string)
27     }
28 )
29 
30 # decode_raw可以把字符串解析为图像对应的像素数组
31 # cast转换数据类型
32 image = tf.decode_raw(features['image_raw'], tf.uint8)
33 label = tf.cast(features['label'], tf.int32)
34 pixels = tf.cast(features['pixels'], tf.int32)
35 
36 with tf.Session() as sess:
37     coord = tf.train.Coordinator()
38     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
39 
40     for i in range(15):  # 每次执行sess.run()都会从队列中取出一个样例,这样就会导致之后处理时可能不是同一个样例,没注意这个问题,这两天就卡在这上面了
41         image_value, label_value, pixels_value = sess.run([image, label, pixels])
42         print(label_value)
43 
44         # 可视化, 可视化之前需要把一维数组转为二维数组
45         image_value = np.reshape(image_value, [28, 28])
46         plt.imshow(image_value)
47         plt.show()

 

7.2 图像数据处理

一张RGB图像可以看成一个三维矩阵,矩阵中的每个数字表示图像上不同位置、不同颜色的亮度。图像在存储时,并不是直接记录这些矩阵中的数字,而是记录经过压缩编码之后的结果。所以要将一张图像还原成一个三维矩阵,需要解码的过程。TF提供了对jpeg、png格式图像的编码/解码函数。

图像编码、解码

 1 import tensorflow as tf
 2 import matplotlib.pyplot as plt
 3 
 4 with tf.gfile.GFile('/home/error/cat.jpg', 'rb') as f:
 5     image_raw_data = f.read()
 6 # 对jpeg格式的图像进行解码, 得到图像对应的三维矩阵, 得到一个tensor
 7 image_data = tf.image.decode_jpeg(image_raw_data)  # 得到一个tensor, (1797, 2673, 3)  dtype=uint8
 8 
 9 # 编码
10 encoded_image = tf.image.encode_jpeg(image_data)  # 得到一个tensor
11 
12 
13 with tf.Session() as sess:
14     plt.imshow(sess.run(image_data))
15     plt.show()
16 
17     with tf.gfile.GFile('/home/error/cat_bk.jpg', 'wb') as f:
18         f.write(sess.run(encoded_image))

 

图像大小调整

图像大小是不固定的,但神经网络输入节点的个数是固定的,所以在将图像的像素作为输入提供给神经网络之前,需要先将图像的大小统一。

 1 import tensorflow as tf
 2 import matplotlib.pyplot as plt
 3 
 4 with tf.gfile.FastGFile('/home/error/cat.jpg', 'rb') as f:
 5     image_raw_data = f.read()
 6 # 对jpeg格式的图像进行解码, 得到图像对应的三维矩阵, 得到一个tensor
 7 image_data = tf.image.decode_jpeg(image_raw_data)  # (1797, 2673, 3)
 8 
 9 # 在图像处理之前将图像由uint8转为实数类型
10 img_data = tf.image.convert_image_dtype(image_data, dtype=tf.float32)
11 
12 resized_image = tf.image.resize_images(img_data, [300, 300], method=0)  # method取值为0~3
13 
14 with tf.Session() as sess:
15     print(resized_image)
16 
17     plt.imshow(sess.run(resized_image))
18     plt.show()

 

裁剪和填充,居中

1 # 裁剪, 如果原始图像的尺寸大于目标图像, 会自动截取原始图像居中的部分
2 croped = tf.image.resize_image_with_crop_or_pad(img_data, 1000, 1000)
3 # 填充, 如果目标图像的尺寸大于原始图像, 会自动在原始图像的四周填充全0背景
4 paded = tf.image.resize_image_with_crop_or_pad(img_data, 3000, 3000)

按比例裁剪,居中

1 # 按比例截取原始图像居中的部分, 比例为(0, 1]之间的实数
2 central_cropped = tf.image.central_crop(img_data, 0.5)

在指定区域进行裁剪和填充

1 # 裁剪给定区域的图像, 该函数对给出的尺寸有一定的要求, 否则报错
2 croped_bound = tf.image.crop_to_bounding_box(img_data, 500, 500, 800, 800)
3 
4 # 填充, 图像从(500, 500)开始, 左侧和上侧全0背景, 显示图像后, 继续是全0背景。该函数对给出的尺寸有一定的要求, 否则报错
5 paded_bound = tf.image.pad_to_bounding_box(img_data, 500, 500, 2500, 3500)  # offset_height + img_height < target_height

 

图像翻转

图像的翻转不应该影响识别的效果,因此在训练图像识别神经网络时,可以随机地翻转训练图像,这样训练得到的模型可以识别不同角度的实体。

 1 # 上下翻转
 2 flipped = tf.image.flip_up_down(img_data)
 3 # 左右翻转
 4 flipped = tf.image.flip_left_right(img_data)
 5 # 沿对角线翻转, 主对角线
 6 flipped = tf.image.transpose_image(img_data)
 7 
 8 # 以50%的概率上下翻转图像
 9 flipped = tf.image.random_flip_up_down(img_data)
10 # 以50%的概率左右翻转图像
11 flipped = tf.image.random_flip_left_right(img_data)

图像色彩的调整

调整图像的亮度、对比度、饱和度和色相都不会影响识别结果,因此可以随机地调整这些属性

调整亮度

 1 # 调整亮度, 负号是调暗, -1为黑屏, 正数是调亮, 1为白屏
 2 adjusted = tf.image.adjust_brightness(img_data, -0.5)
 3 
 4 # 在[-max_delta, max_delta]范围内随机调整图像亮度
 5 adjusted = tf.image.random_brightness(img_data, 1)
 6 
 7 
 8 # 截断
 9 # 色彩调整的API可能导致像素的实数值超出0.0~1.0的范围,因此在最终输出图像前需要将其截断在0.0~1.0范围内,
10 # 否则不仅图像不能正常可视化,以此为输入的神经网络的训练质量也可能会受到影响。
11 # 如果对图像有多项处理,那么截断应该在所有处理完成之后进行
12 adjusted = tf.clip_by_value(adjusted, 0.0, 1.0)

调整对比度

1 # 调整对比度,将对比度减少到0.5倍
2 adjusted = tf.image.adjust_contrast(img_data, 0.5)
3 # 调整对比度,将对比度增加5倍
4 adjusted = tf.image.adjust_contrast(img_data, 5)
5 # 将对比度在[0.5, 5]范围内随机调整
6 adjusted = tf.image.random_contrast(img_data, 0.5, 5)

调整色相

1 # 分别取值[0.1, 0.3, 0.6, 0.9], 色彩从绿变为蓝,又变为红
2 adjusted = tf.image.adjust_hue(img_data, 0.9)
3 # 取值在[0.0, 0.5]之前
4 adjusted = tf.image.random_hue(0, 0.8)

调整饱和度

1 # 调整饱和度
2 adjusted = tf.image.adjust_saturation(img_data, -5)  # 饱和度-5(+5就是加5)
3 # 在[-5, 5]范围内随机调整饱和度
4 tf.image.random_saturation(img_data, -5, 5)

注意:对于色相、饱和度,需要输入数据的channels为3,例如mnist数据就不行,亮度、对比度没有限制。

 

将图像标准化

即将图像的亮度均值变为0,方差变为1

1 adjusted = tf.image.per_image_standardization(img_data)

处理标注框

 1 import tensorflow as tf
 2 import matplotlib.pyplot as plt
 3 
 4 with tf.gfile.GFile('/home/error/cat.jpg', 'rb') as f:
 5     image_raw_data = f.read()
 6 # 对jpeg格式的图像进行解码, 得到图像对应的三维矩阵, 得到一个tensor
 7 decoded_image_data = tf.image.decode_jpeg(image_raw_data)  # (1797, 2673, 3)
 8 
 9 # 在图像处理之前将图像由uint8转为实数类型
10 converted_image_data = tf.image.convert_image_dtype(decoded_image_data, dtype=tf.float32)
11 
12 # 把图像缩小一些,让标注框更清楚
13 resized_image_data = tf.image.resize_images(converted_image_data, [180, 267])
14 
15 # 输入是一个batch的数据,也就是多张图片组成的四维矩阵,所以需要加1个维度
16 expanded_image_data = tf.expand_dims(resized_image_data, 0)  # (1, 180, 267, ?)
17 # 标注框,数值为比例,秩为3,设置了两个标注框,为啥秩为3呢??
18 boxes = tf.constant([[[0.05, 0.05, 0.9, 0.7], [0.35, 0.47, 0.5, 0.56]]])
19 drawn_image_data = tf.image.draw_bounding_boxes(expanded_image_data, boxes)
20 
21 adjusted = tf.clip_by_value(drawn_image_data[0], 0.0, 1.0)
22 
23 with tf.Session() as sess:
26 
27     plt.imshow(sess.run(adjusted))
28     plt.show()

随机截取图像

随机截取图像上有信息含量的部分也是一种提高模型健壮性的方式。这样可以使训练得到的模型不受识别物体大小的影响。

 1 import tensorflow as tf
 2 import matplotlib.pyplot as plt
 3 
 4 with tf.gfile.FastGFile('/home/yangxl/files/cat.jpg', 'rb') as f:
 5     image_raw_data = f.read()
 6 # 对jpeg格式的图像进行解码, 得到图像对应的三维矩阵, 得到一个tensor
 7 decoded_image_data = tf.image.decode_jpeg(image_raw_data)  # (1797, 2673, 3)
 8 # 在图像处理之前将图像由uint8转为实数类型
 9 converted_image_data = tf.image.convert_image_dtype(decoded_image_data, dtype=tf.float32)
10 # 把图像缩小一些,让标注框更清楚
11 resized_image_data = tf.image.resize_images(converted_image_data, [180, 267], method=1)
12 
13 # 输入是一个batch的数据,也就是多张图片组成的四维矩阵,所以需要加1个维度
14 expanded_image_data = tf.expand_dims(resized_image_data, 0)  # (1, 180, 267, ?)
15 # 标注框,数值为比例,秩为3,设置了两个标注框
16 boxes = tf.constant([[[0.05, 0.05, 0.9, 0.7], [0.35, 0.47, 0.5, 0.56]]])
17 
18 # 扩维之前的shape
19 begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
20     tf.shape(resized_image_data), bounding_boxes=boxes,
21     min_object_covered=0.4
22 )
23 image_with_box = tf.image.draw_bounding_boxes(expanded_image_data, bbox_for_draw)
24 
25 adjusted = tf.clip_by_value(image_with_box[0], 0.0, 1.0)
26 distorted_image = tf.slice(adjusted, begin, size)
27 
28 with tf.Session() as sess:
29     plt.imshow(sess.run(distorted_image))  # 如果不使用slice, 像这样plt.imshow(sess.run(adjusted)), 可视化结果为不截取只随机标注
30     plt.show()
31 # 其实就多了两行:
32 # sample_distorted_bounding_box和slice

 

7.2.2 图像预处理完整样例

因为调整亮度、对比度、饱和度和色相的顺序会影响最后得到的结果,所以可以定义多种不同的顺序。具体使用哪一种可以随机选定。这样可以进一步降低无关因素对模型的影响。

 1 import tensorflow as tf
 2 import numpy as np
 3 import matplotlib.pyplot as plt
 4 
 5 def distort_color(image, color_ordering=0):
 6     if color_ordering == 0:
 7         image = tf.image.random_brightness(image, max_delta=32./255.)
 8         image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
 9         image = tf.image.random_hue(image, max_delta=0.2)
10         image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
11     elif color_ordering == 1:
12         image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
13         image = tf.image.random_brightness(image, max_delta=32. / 255.)
14         image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
15         image = tf.image.random_hue(image, max_delta=0.2)
16     return tf.clip_by_value(image, 0.0, 1.0)
17 
18 # 给定一张解码后的图像、目标图像的尺寸以及图像上的标注框,此函数可以对给出的图像进行预处理。
19 # 这个函数的输入图像是图像识别问题中的原始训练数据,而输出是神经网络模型的输入层。
20 # 注意,只处理模型的训练数据,对于预测数据,一般不需要使用随机变换的步骤。
21 def preprocessed_for_train(image, height, width, bbox):
22     # 如果没有提供标注框,则认为整个图像就是需要关注的部分。
23     if bbox is None:
24         bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])  # 秩为3
25 
26     # 转换图像张量的类型
27     if image.dtype != tf.float32:
28         image = tf.image.convert_image_dtype(image, tf.float32)
29 
30     # 随机截取图像,减少需要关注的物体大小对图像识别算法的影响。
31     bbox_begin, bbox_size, bbox_for_draw = tf.image.sample_distorted_bounding_box(tf.shape(image), bounding_boxes=bbox)
32     distorted_image = tf.slice(image, bbox_begin, bbox_size)
33 
34     # 调整大小
35     distorted_image = tf.image.resize_images(distorted_image, [height, width], method=np.random.randint(4))
36 
37     # 随机翻转
38     distorted_image = tf.image.random_flip_left_right(distorted_image)
39 
40     # 调整色彩
41     distorted_image = distort_color(distorted_image, np.random.randint(2))
42 
43     return distorted_image
44 
45 
46 image_raw_data = tf.gfile.GFile('/home/yangxl/files/cat.jpg', 'rb').read()
47 with tf.Session() as sess:
48     img_data = tf.image.decode_jpeg(image_raw_data)
49     boxes = tf.constant([[[0.05, 0.05, 0.9, 0.7], [0.35, 0.47, 0.5, 0.56]]])
50 
51     # 运行6次获得6中不同的图像
52     for i in range(6):
53         result = preprocessed_for_train(img_data, 299, 299, boxes)
54         plt.imshow(result.eval())
55         plt.show()

 

7.3 多线程输入数据处理框架

虽然图像预处理方法可以减小无关因素对图像识别模型效果的影响,但是这些复杂的预处理过程也会减慢整个训练过程。为了避免图像预处理成为神经网络模型训练效率的瓶颈,tensorflow提供了一套多线程处理输入数据的框架。

tensorflow中,队列不仅是一种数据结构,更提供了多线程机制。队列也是tensorflow多线程输入数据处理框架的基础。

7.3.1介绍队列和多线程

7.3.2介绍前三步

7.3.3介绍最后一步

7.3.4完整示例

7.4介绍数据集

 

7.3.1 队列与多线程

队列和变量类似,都是计算图上有状态的节点。其他状态节点可以修改它们的状态。

队列操作:

 1 import tensorflow as tf
 2 
 3 # 先进先出队列
 4 q = tf.FIFOQueue(2, 'int32')
 5 # 队列初始化
 6 init = q.enqueue_many([[0, 10], ])  # 这个至少要有两层括号,否则报错:Shape () must have rank at least 1
 7 
 8 x = q.dequeue()
 9 y = x + 1
10 q_inc = q.enqueue([y])  # 可以没有括号
11 
12 with tf.Session() as sess:
13     init.run()  # 队列初始化需要明确调用
14     for i in range(5):
15         # 10, 1   1, 11   11, 2   2, 12   12, 3
16         sess.run(q_inc)
17 
18     print(sess.run(x))  # 12
19     print(sess.run(x))  # 3

tf提供了FIFOQueue和RandomShuffleQueue两种队列。FIFOQueue是先进先出队列,RandomShuffleQueue会将队列中的元素打乱,每次出队列操作得到的是从当前队列所有元素中随机选择的一个。在训练神经网络时希望每次使用的训练数据尽量随机,RandomShuffleQueue就提供了这样的功能。

tf提供了tf.train.Coordinator和tf.QueueRunner两个类来完成多线程协同的功能。

tf.train.Coordinator主要用于协同多个线程一起停止,并提供了should_stop、request_stop和join三个函数。在启动线程之前,需要先声明一个tf.train.Coordinator类,并将这个类传入每个创建的线程中。启动的线程需要一直查询tf.Coordinator类中提供的should_stop函数,当这个函数的返回值为True时,则当前线程退出。每个线程都可以通过调用request_stop函数来通知其他线程退出,即当某一个线程调用request_stop函数之后,should_stop函数的返回值被设置为True,这样其他线程就可以同时退出了。

 1 import tensorflow as tf
 2 import numpy as np
 3 import threading
 4 import time
 5 
 6 def MyLoop(coord, worker_id):
 7     while not coord.should_stop():
 8         if np.random.rand() < 0.05:
 9             print('Stoping from id: %d\n' % worker_id)
10             coord.request_stop()
11         else:
12             print('working on id: %d\n' % worker_id)
13         time.sleep(1)
14 
15 
16 coord = tf.train.Coordinator()
17 threads = [threading.Thread(target=MyLoop, args=(coord, i)) for i in range(5)]
18 
19 for t in threads:
20     t.start()
21 
22 # 等待所有线程退出
23 coord.join(threads)

tf.train.QueueRunner主要用于启动多个线程来操作同一个队列。启动的线程可以通过tf.Coordinator类来统一管理。

 1 queue = tf.FIFOQueue(100, 'float')
 2 enqueue_op = queue.enqueue([tf.random_normal([1])])
 3 
 4 # 启动5个线程来操作队列,每个线程中运行的是enqueue_op
 5 qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)
 6 # 将qr加入到计算图指定的集合中,如果没有指定集合则默认加到tf.GraphKeys.QUEUE_RUNNERS
 7 tf.train.add_queue_runner(qr)
 8 
 9 out_tensor = queue.dequeue()
10 
11 with tf.Session() as sess:
12     coord = tf.train.Coordinator()
13     # 使用tf.train.QueueRunner时,需要明确调用tf.train.start_queue_runners来启动所有线程。
14     # tf.train.start_queue_runners会默认启动tf.GraphKeys.QUEUE_RUNNERS集合中的所有QueueRunner。
15     # 因为这个函数只支持启动指定集合中的QueueRunner,所以tf.train.add_queue_runner和tf.train.start_queue_runners会指定同一个集合。
16     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
17 
18     for i in range(3):
19         print(sess.run(out_tensor))
20 
21     coord.request_stop()
22     coord.join(threads)

 

7.3.2 输入文件队列

使用TF中的队列管理输入文件列表。

虽然一个TFRecord文件中可以保存多个训练样例,但是当训练数据量较大时,可以将数据分成多个TFRecord文件来提高处理效率。tensorflow提供了tf.train.match_filenames_once函数来获取符合一个正则表达式的所有文件,得到的文件列表可以通过tf.train.string_input_producer函数进行有效的管理。注意,在使用tf.train.match_filenames_once时需要初始化一些变量,tf.local_variables_initizliaer().run()

tf.train.string_input_producer函数会使用初始化时提供的文件列表创建一个输入队列,输入队列中原始的元素为文件列表中的所有文件,创建好的输入队列可以作为文件读取函数的参数。

每次调用文件读取函数时,该函数会先判断当前是否已有打开的文件可读,如果没有或者打开的文件已经读完,这个函数就会从输入队列中出队一个文件并从这个文件中读取数据。

1 reader = tf.TFRecordReader()
2 # 创建输入队列
3 filename_queue = tf.train.string_input_producer(['/home/error/output.tfrecords'])
4 # 读取样例
5 _, serializd_example = reader.read(filename_queue)

通过设置shuffle参数,tf.train.string_input_producer函数支持随机打乱文件列表中文件出队的顺序。随机打乱文件顺序以及加入输入队列的过程会跑在一个单独的线程上,这样不会影响获取文件的速度。

tf.train.string_input_producer生成的输入队列可以同时被多个文件读取线程操作,而且输入队列会将队列中的文件均匀地分配给不同的线程,不出现有些文件被处理多次而有些文件还没被处理的情况。

当一个输入队列中的所有文件都被处理后,它会将初始化时提供的文件列表中的文件全部重新加入队列。可以通过设置num_epochs参数来限制加载初始文件列表的最大轮数。当所有文件都已经被使用了设定的轮数后,如果继续尝试读取新的文件,输入队列会报错:OutOfRange。在测试神经网络时,因为所有测试数据只需要使用一次,所有可以将num_epochs设置为1,这样在计算完一轮之后程序将自动停止。

在展示tf.train.match_filenames_once和tf.train.string_input_producer函数的使用方法之前,先生成两个TFRecords文件,

 1 num_shards = 2
 2 instances_per_shard = 2
 3 
 4 def _int64_feature(value):
 5     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 6 
 7 # 生成两个文件,每个文件保存2个样例
 8 for i in range(num_shards):
 9     filename = '/home/yangxl/files/tfrecords/data.tfrecords-%.5d-of-%.5d' % (i, num_shards)  # 书上是带括号的,('...')
10     writer = tf.python_io.TFRecordWriter(filename)
11 
12     for j in range(instances_per_shard):
13         example = tf.train.Example(features=tf.train.Features(feature={
14             'i': _int64_feature(i),
15             'j': _int64_feature(j)
16         }))
17         writer.write(example.SerializeToString())
18     writer.close()

读取多个TFRecord文件,获取样例数据,

 1 files_list = tf.train.match_filenames_once('/home/error/tfrecord/data.tfrecords-*')  # 参数为正则表达式
 2 filename_queue = tf.train.string_input_producer(files_list, num_epochs=2, shuffle=True)
 3 
 4 reader = tf.TFRecordReader()
 5 _, serialized_example = reader.read(filename_queue)
 6 features = tf.parse_single_example(
 7     serialized_example,
 8     features={
 9         'i': tf.FixedLenFeature([], tf.int64),
10         'j': tf.FixedLenFeature([], tf.int64)
11     }
12 )
13 
14 with tf.Session() as sess:
15     # 虽然在本段程序中没有声明任何变量,但是使用tf.train.match_filenames_once函数时需要初始化一些变量
16     tf.local_variables_initializer().run()
17 
18     print(sess.run(files_list))
19 
20     coord = tf.train.Coordinator()
    # tf.train.string_input_producer创建文件队列也是调用了FIFOQueue、enqueue_many、QueueRunner、add_queue_runner这几个操作,所以需要明确调用启动线程的语句。
21 threads = tf.train.start_queue_runners(sess=sess, coord=coord) 22 23 for i in range(6): 24 print(sess.run([features['i'], features['j']])) 25 26 coord.request_stop() 27 coord.join(threads)

 

7.3.3 组合训练数据(batching)

从文件列表中读取单个样例,将单个样例进行预处理,将经过预处理的单个样例组织成batch,提供给神经网络输入层。tensorflow提供了tf.train.batchtf.train.shuffle_batch函数来将单个的样例组织成batch形式输出。这两个函数都会生成一个队列,队列的入队操作是生成单个样例的方法,而每次出队得到的是一个batch的样例,二者唯一的区别在于是否将数据顺序打乱。

tf.train.batchtf.train.shuffle_batch的使用方法,

 1 files_list = tf.train.match_filenames_once('/home/error/tfrecord/data.tfrecords-*')  # 参数为正则表达式
 2 filename_queue = tf.train.string_input_producer(files_list, shuffle=False)
 3 
 4 reader = tf.TFRecordReader()
 5 _, serialized_example = reader.read(filename_queue)
 6 features = tf.parse_single_example(
 7     serialized_example,
 8     features={
 9         'i': tf.FixedLenFeature([], tf.int64),
10         'j': tf.FixedLenFeature([], tf.int64)
11     }
12 )
13 example, label = features['i'], features['j']
14 
15 batch_size = 5
16 # 队列中最多可以存储的样例个数。一般来说,队列的大小与每个batch的大小相关。
17 capacity = 1000 + 3 * batch_size
18 
19 
20 # 使用batch来组合样例。
# capacity给出了队列的最大容量,当队列长度等于容量时,tensorflow暂停入队操作,而只是等待元素出队;当队列长度小于容量时,tensorflow自动重新启动入队操作。
21 example_batch, label_batch = tf.train.batch( 22 [example, label], batch_size=batch_size, capacity=capacity 23 ) 24 25 with tf.Session() as sess: 26 # 虽然在本段程序中没有声明任何变量,但是使用tf.train.match_filenames_once函数时需要初始化一些变量 27 tf.local_variables_initializer().run() 28 print(sess.run(files_list)) 29 30 coord = tf.train.Coordinator() 31 threads = tf.train.start_queue_runners(sess=sess, coord=coord) 32 33 for i in range(3): 34 cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch]) 35 print(cur_example_batch, cur_label_batch) 36 37 coord.request_stop() 38 coord.join(threads)

tf.train.batch和tf.train.shuffle_batch的区别在于,shuffle_batch多一个参数min_after_dequeue,限制了出队时队列中元素的最少个数。当队列中元素太少时,随机打乱样例顺序的作用就不大了。当队列中元素不够时,出队操作将等待更多的元素入队才会完成。

# min_after_dequeue参数限制了出队时最少元素的个数来保证随机打乱顺序的作用。当出队函数被调用但是队列中元素不够时,出队操作将等待更多的元素入队才会完成。
example_batch, label_batch = tf.train.shuffle_batch(
    [example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=30
)

这两个函数除了可以将单个训练数据整理成输入batch,还提供了并行化处理输入数据的方法。通过设置num_threads参数,可以指定多个线程同时执行入队操作。入队操作就是数据读取以及预处理过程。当num_threads大于1时,多个线程会同时读取一个文件中的不同样例并进行预处理。

如果需要多个线程处理不同文件中的样例,可以使用tf.train.batch_jointf.train.shuffle_batch_join函数。此函数会从输入文件队列中获取不同文件分配给不同的线程。一般来说,输入文件队列时通过tf.train.string_input_producer函数生成的,这个函数会平均分配文件以保证不同文件中的数据会尽量平均地使用。

tf.train.shuffle_batch和tf.train.shuffle_batch_join都可以完成多线程并行的方式来进行数据处理,但它们各有优劣。对于shuffle_batch,不同线程会读取同一个文件,如果一个文件中的样例比较相似(比如都属于同一个类别),那么神经网络的训练效果有可能受到影响。所以使用shuffle_batch时,需要尽量将同一个TFRecord文件中的样例随机打乱。而使用shuffle_batch_join时,不同线程会读取不同文件,如果读取数据的线程数比文件数还多,那么多个线程可能会读取同一个文件中相近部分的数据。而且多个线程读取多个文件可能导致过多的硬盘寻址,从而降低读取效率。

3个shuffle:string_input_producer中的shuffle打乱队列中的文件;shuffle_batch中的shuffle打乱队列中的元素;shuffle_batch_join中的shuffle。

7.3.4 输入数据处理框架

事先准备,

把mnist数据集转为10个TFRecord文件,

 1 import tensorflow as tf
 2 import numpy as np
 3 from tensorflow.examples.tutorials.mnist import input_data
 4 import math
 5 
 6 mnist = input_data.read_data_sets('/home/error/MNIST_DATA/', dtype=tf.uint8, one_hot=True)
 7 
 8 images = mnist.train.images
 9 num_examples = mnist.train.num_examples
10 labels = mnist.train.labels
11 pixels = images.shape[1]
12 height = width = int(math.sqrt(pixels))
13 
14 num_shards = 10
15 # 每个文件有多少数据
16 instances_per_shard = int(mnist.train.num_examples / num_shards)  # 5500
17 
18 def _int64_feature(value):
19     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
20 
21 def _bytes_feature(value):
22     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
23 
24 
25 for i in range(num_shards):
26     filename = '/home/error/tfrecord_mnist/data.tfrecords-%.5d-of-%.5d' % (i, num_shards)
27     writer = tf.python_io.TFRecordWriter(filename)
28 
29     for j in range(instances_per_shard * i, instances_per_shard * (i+1)):
30         example = tf.train.Example(features=tf.train.Features(feature={
31             'image': _bytes_feature(images[j].tostring()),  # image[j]为长度为784的一维数组
32             'label': _int64_feature(np.argmax(labels[j])),
33             'height': _int64_feature(height),
34             'width': _int64_feature(width),
35             'channels': _int64_feature(1)
36         }))
37         writer.write(example.SerializeToString())
38     writer.close()

预处理mnist数据遇到的问题:

1). 判断类型

2). channels

33     # 随机翻转
34     distorted_image = tf.image.random_flip_left_right(distorted_image)
35     # 调整色彩
36     distorted_image = distort_color(distorted_image, np.random.randint(2))
37 
38     return distorted_image
39 
40 ######################
41 
42 import tensorflow as tf
43 import matplotlib.pyplot as plt
44 import numpy as np
45 from meng42 import preprocessed_for_train
46 from tensorflow.examples.tutorials.mnist import input_data
47 
48 
49 mnist = input_data.read_data_sets('/home/yangxl/files/mnist/', dtype=tf.uint8, one_hot=True)
50 image = mnist.train.images[4]
51 image = image.reshape([28, 28, 1])
52 
53 # 预处理过程中,`if image.dtype != tf.float32:`报错:TypeError: data type not understood
54 # 原因是image.dtype的类型为numpy, 而tf.float32的类型为tensor, 比较之前必须先统一类型。
55 image = tf.constant(image)
56 
57 # 定义神经网络的输入大小
58 image_size = 28
59 # 预处理
60 distort_image = preprocessed_for_train(image, image_size, image_size, None)
61 distort_image = tf.squeeze(distort_image, axis=2)
62 
63 with tf.Session() as sess:
64     tf.global_variables_initializer().run()
65 
66     distort_image_val = sess.run(distort_image)
67     print(distort_image_val.shape)
68     plt.imshow(distort_image_val)
69     plt.show()

完整示例:

  1 import tensorflow as tf
  2 from meng42 import preprocessed_for_train
  3 import mnist_inference
  4 import os
  5 
  6 
  7 files = tf.train.match_filenames_once(pattern='/home/yangxl/files/mnist_tfrecords/mnist.tfrecords-*')
  8 filename_queue = tf.train.string_input_producer(files, shuffle=False, num_epochs=1)
  9 
 10 reader = tf.TFRecordReader()
 11 _, serialized_example = reader.read(filename_queue)
 12 
 13 features = tf.parse_single_example(serialized_example, features={
 14     'image': tf.FixedLenFeature([], tf.string),
 15     'label': tf.FixedLenFeature([], tf.int64),
 16     'height': tf.FixedLenFeature([], tf.int64),
 17     'width': tf.FixedLenFeature([], tf.int64),
 18     'channels': tf.FixedLenFeature([], tf.int64)
 19 })
 20 
 21 image, label = features['image'], features['label']
 22 height, width = features['height'], features['width']
 23 channels = features['channels']
 24 
 25 decoded_image = tf.decode_raw(image, tf.uint8)  # shape=(?,)
 26 decoded_image = tf.reshape(decoded_image, [28, 28, 1])  # shape=(28, 28, 1)
 27 
 28 # 定义神经网络的输入大小
 29 image_size = 28
 30 # 预处理
 31 distort_image = preprocessed_for_train(decoded_image, image_size, image_size, None)  # shape=(28, 28, ?)
 32 distort_image = tf.reshape(distort_image, [28, 28, 1])  # 预处理过程损坏了shape,会在`shuffle_batch`时报错。
 33 
 34 min_after_dequeue = 1000
 35 batch_size = 100
 36 capacity = min_after_dequeue + 3 * batch_size
 37 image_batch, label_batch = tf.train.shuffle_batch([distort_image, label], batch_size, capacity, min_after_dequeue)
 38 
 39 # 训练
 40 BATCH_SIZE = 100
 41 
 42 LEARNING_RATE_BASE = 0.9
 43 LEARNING_RATE_DECAY = 0.9
 44 REGULARIZATION_RATE = 0.0001  # lambda
 45 TRAINING_STEPS = 20000
 46 MOVING_AVERAGE_DACAY = 0.99
 47 
 48 MODEL_SAVE_PATH = '/home/yangxl/files/save_model2'
 49 MODEL_NAME = 'yangxl.ckpt'
 50 
 51 
 52 def train(image_batch, label_batch):
 53     # 因为从池化层到全连接层要进行reshape,所以不能为shape[0]不能为None。
 54     x = tf.placeholder(tf.float32, [BATCH_SIZE, mnist_inference.IMAGE_SIZE, mnist_inference.IMAGE_SIZE, mnist_inference.NUM_CHANNELS], 'x-input')
 55     y_ = tf.placeholder(tf.int64, [BATCH_SIZE], 'y-input')
 56     # 因为从tfrecords文件中读取的label.shape=(), 所以这里进行了相应调整(y_以及用到y_的节点,测试代码也要对应)。
 57 
 58     # 正则化
 59     regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
 60     y = mnist_inference.inference(x, True, regularizer)
 61 
 62     global_step = tf.Variable(0, trainable=False)
 63 
 64     # 滑动平均
 65     variables_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DACAY, global_step)
 66     variables_averages_op = variables_averages.apply(tf.trainable_variables())
 67 
 68     cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=y_)
 69     cross_entropy_mean = tf.reduce_mean(cross_entropy)
 70     loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
 71 
 72     # learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, 10000 / BATCH_SIZE, LEARNING_RATE_DECAY, staircase=True)
 73     learning_rate = 0.01
 74     train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step)
 75     with tf.control_dependencies([train_step, variables_averages_op]):
 76         train_op = tf.no_op(name='train')
 77 
 78     with tf.Session() as sess:
 79         tf.local_variables_initializer().run()
 80         tf.global_variables_initializer().run()
 81 
 82         coord = tf.train.Coordinator()
 83         threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 84 
 85         saver = tf.train.Saver()
 86 
 87         image_batch_val, label_batch_val = sess.run([image_batch, label_batch])
 88         for i in range(TRAINING_STEPS):
 89             _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: image_batch_val, y_: label_batch_val})
 90 
 91             if i % 1000 == 0:
 92                 print('after %d training steps, loss on training batch is %g ' % (i, loss_value))
 93                 saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
 94 
 95         coord.request_stop()
 96         coord.join(threads)
 97 
 98 
 99 if __name__ == '__main__':
100     train(image_batch, label_batch)

 

把flower文件转为TFRecord文件,

 1 import tensorflow as tf
 2 import os
 3 import glob
 4 from tensorflow.python.platform import gfile
 5 import numpy as np
 6 
 7 INPUT_DATA = '/home/error/flower_photos'  # 输入文件
 8 
 9 
10 VALIDATION_PERCENTAGE = 10
11 TEST_PERCENTAGE = 10
12 
13 def _int64_feature(value):
14     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
15 
16 def _bytes_feature(value):
17     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
18 
19 def create_image_lists(sess):
20     sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]  # 当前目录和子目录
21     # print(sub_dirs)
22     is_root_dir = True
23 
24     current_labels = 0
25 
26     # 读取所有子目录
27     for sub_dir in sub_dirs:
28         if is_root_dir:  # 把第一个排除了
29             is_root_dir = False
30             continue
31 
32         # 获取一个子目录中所有的图片文件
33         extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
34         file_list = []
35         dir_name = os.path.basename(sub_dir)  # '/'最后面的部分
36         print(dir_name)
37         for extension in extensions:
38             file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
39             file_list.extend(glob.glob(file_glob))  # glob.glob返回一个匹配该模式的列表, glob和os配合使用来操作文件
40         if not file_list:
41             continue
42 
43         OUTPUT_DATA = '/home/error/inception_v3_data/inception_v3_data_' + dir_name + '.tfrecords'  # 输出文件
44         writer = tf.python_io.TFRecordWriter(OUTPUT_DATA)
45 
46         # 处理图片数据
47         for file_name in file_list:
48             print(file_name)
49             image_raw_data = gfile.FastGFile(file_name, 'rb').read()  # 二进制数据
50             image = tf.image.decode_jpeg(image_raw_data)  # tensor, dtype=uint8  333×500×3   色道0~255
51             # if image.dtype != tf.float32:
52             #     image = tf.image.convert_image_dtype(image, dtype=tf.float32)  # 色道值0~1
53             # image = tf.image.resize_images(image, [299, 299])
54             image_value = sess.run(image)  # numpy.ndarray
55             # print(image_value.shape)
56             height, width, channles = image_value.shape
57             label = current_labels
58             example = tf.train.Example(features=tf.train.Features(feature={
59                 'image': _bytes_feature(image_value.tostring()),
60                 'label': _int64_feature(np.argmax(label)),
61                 'height': _int64_feature(height),
62                 'width': _int64_feature(width),
63                 'channels': _int64_feature(channles)
64             }))
65             writer.write(example.SerializeToString())
66         writer.close()
67 
68         current_labels += 1
69 
70 
71 with tf.Session() as sess:
72     create_image_lists(sess)

 

7.4 数据集

除队列外,tensorflow提供了一套更高层的数据处理框架。在新的框架中,每一个数据来源被抽象成一个“数据集”,开发者可以以数据集为基本对象,方便地进行batching、shuffle等操作。推荐使用数据集作为输入数据的首选框架。数据集是tensorflow的核心部件。

7.4.1 数据集的基本使用方法

在数据集框架中,每个数据集代表一个数据来源:数据可能来自一个tensor,一个TFRecord文件,一个文本文件,或者经过sharding的一系列文件等。

由于训练数据通常无法全部写入内存中,从数据集中读取数据时需要使用一个迭代器按顺序进行读取,这点与队列的dequeue()操作和Reader的read()操作类似。与队列相似,数据集也是计算图上的一个节点。

示例,从一个张量创建一个数据集,

 1 # 从数组创建数据集。不同数据来源,需要调用不同的构造方法。
 2 input_data = [1, 2, 3, 4, 5]
 3 dataset = tf.data.Dataset.from_tensor_slices(input_data)
 4 
 5 # 定义一个迭代器用于遍历数据集。因为上面定义的数据集没有使用placeholder作为输入参数,所以可以使用最简单的one_shot_iterator。
 6 iterator = dataset.make_one_shot_iterator()
 7 
 8 x = iterator.get_next()
 9 y = x * x
10 
11 with tf.Session() as sess:
12     for i in range(len(input_data)):
13         print(sess.run([x, y]))

在真实项目中,训练数据通常保存在硬盘文件中。比如在自然语言处理任务中,训练数据通常以每行一条数据的形式存在文本文件中。这时可以用TextLineDataset来构造。

 1 # 从文件创建数据集
 2 # windows中必须要加后缀。'D:\\files\\tf\\firsts.txt'
 3 # 只有一个文件时,可以只传一个字符串格式的文件名。
 4 input_files = ['/home/error/checkpoint', '/home/error/ten']
 5 dataset = tf.data.TextLineDataset(input_files)
 6 
 7 iterator = dataset.make_one_shot_iterator()
 8 
 9 x = iterator.get_next()
10 
11 with tf.Session() as sess:
12     for i in range(20):
13         print(sess.run(x))

在图像相关任务中,训练数据通常以TFRecords形式存储,这时可以用TFRecordDataset来读取数据。与文本文件不同的是,每个tfrecord都有自己不同的feature格式,因此需要提供一个parser函数来解析所读取的tfrecord格式的数据。

 1 # 从TFRecord文件创建数据集
 2 input_files = ['/home/error/tt.tfrecords', '/home/error/tt2.tfrecords']
 3 dataset = tf.data.TFRecordDataset(input_files)
 4 
 5 # map()函数表示对数据集中的每一条数据调用相应的方法。
 6 # TFRecordDataset读出的是二进制数据,需要通过map调用parser来对二进制数据进行解析。
 7 dataset = dataset.map(parser)
 8 
 9 iterator = dataset.make_one_shot_iterator()
10 features = iterator.get_next()
11 
12 with tf.Session() as sess:
13     for i in range(5):  # 不能超过样例个数,否则报错
14         print(sess.run(features['name']))

把上面的实例改成含有占位符的形式:

 1 def parser(record):
 2     features = tf.parse_single_example(
 3         record,
 4         features={
 5             'name': tf.FixedLenFeature([], tf.string),
 6             'image': tf.FixedLenFeature([], tf.string),
 7             'label': tf.FixedLenFeature([], tf.int64),
 8             'height': tf.FixedLenFeature([], tf.int64),
 9             'width': tf.FixedLenFeature([], tf.int64),
10             'channels': tf.FixedLenFeature([], tf.int64)
11         }
12     )
13     return features
14 
15 # 从TFRecord文件创建数据集
16 input_files = tf.placeholder(tf.string)
17 dataset = tf.data.TFRecordDataset(input_files)
18 
19 # map()函数表示对数据集中的每一条数据调用相应的方法。
20 # TFRecordDataset读出的是二进制数据,需要通过map调用parser来对二进制数据进行解析。
21 dataset = dataset.map(parser)
22 
23 iterator = dataset.make_initializable_iterator()
24 features = iterator.get_next()
25 
26 with tf.Session() as sess:
27     sess.run(iterator.initializer, feed_dict={input_files: ['/home/error/tt.tfrecords', '/home/error/tt2.tfrecords']})
28    # 因为不同数据来源的数据量大小难以预知。使用while True可以把所有数据遍历一遍。
29     while True:
30         try:
31             print(sess.run([features['name'], features['height']]))
32         except tf.errors.OutOfRangeError:
33             break

 

7.4.2 数据集的高层操作

dataset = dataset.map(parser)

对数据集中的每一条数据调用参数中指定的parser方法,经过处理后的数据重新组合成一个数据集。

1 distorted_image = preprocess_for_train(
2     decoded_image, image_size, image_size, None
3 )
4 转为
5 dataset = dataset.map(
6     lambda x: preprocess_for_train(x, image_size, image_size, None)
7 )

这样处理的优点是,返回一个新数据集,可以直接继续调用其他高层操作。

在队列框架中,预处理、shuffle、batch等操作有的在队列上进行,有的在图片张量上进行,整个处理流程在处理队列和张量的代码片段中来回切换。而在数据集操作中,所有操作都在数据集上进行。

 

dataset = dataset.shuffle(buffer_size)  # 随机打乱顺序
dataset = dataset.batch(batch_size)  # 将数据组合成batch

shuffle方法中的buffer_size等效于tf.train.shuffle_batch的min_after_dequeue,shuffle算法在内部使用一个缓冲区保存buffer_size条数据,每读入一个新数据时,从这个缓冲区随机选择一条数据进行输出。缓冲区越大,随机性能越好,但占用的内存也越多。

batch方法的batch_size代表要输出的每个batch由多少条数据组成。如果数据集包含多个张量,那么batch操作将对每个张量分开进行。例如,如果数据集中的每个数据是image、label两个张量,其中image的维度是[300, 300],label的维度是[],batch_size是128,那么经过batch操作后的数据集的每个输出将包含两个维度分别为[128, 300, 300]和[128]的张量。

 

dataset = dataset.repeat(N)  # 将数据集重复N份

将数据集重复N份,每一份数据被称为一个epoch。

需要指出的是,如果数据集在repeat之前进行了shuffle操作,输出的每个epoch中随机shuffle的结果并不会相同。因为repeat和map、shuffle、batch等操作一样,都只是计算图上的一个计算节点,repeat只代表重复相同的处理过程,并不会记录前一epoch的处理结果。

其他方法,

dataset.concatenate()  # 将两个数据集顺序连接起来
dataset.take(N)  # 从数据集中读取前N项数据
dataset.skip(N)  # 在数据集中跳过前N项数据
dataset.flat_map()  # 从多个数据集中轮流读取数据

 

与队列框架下的样例不同的是,在训练数据集之外,还另外读取了测试数据集,并对测试集进行了略微不同的预处理。在训练时,调用preprocessed_for_train对图像进行随机反转等预处理操作;而在测试时,测试集以原本的样子直接输入测试。

 1 import tensorflow as tf
 2 from meng42 import preprocessed_for_train
 3 
 4 train_files = tf.train.match_filenames_once('/home/yangxl/files/mnist_tfrecords/mnist.tfrecords-*')
 5 test_files = tf.train.match_filenames_once('/home/yangxl/files/mnist_tfrecords/mnist.tfrecords-0000[49]-of-00010')
 6 
 7 
 8 def parser(record):
 9     features = tf.parse_single_example(
10         record,
11         features={
12             'image': tf.FixedLenFeature([], tf.string),
13             'label': tf.FixedLenFeature([], tf.int64),
14             'height': tf.FixedLenFeature([], tf.int64),
15             'width': tf.FixedLenFeature([], tf.int64),
16             'channels': tf.FixedLenFeature([], tf.int64),
17         }
18     )
19 
20     decoded_image = tf.decode_raw(features['image'], tf.uint8)
21     decoded_image = tf.reshape(decoded_image, [features['height'], features['width'], features['channels']])
22     label = features['label']
23     return decoded_image, label
24 
25 
26 image_size = 28
27 batch_size = 100
28 shuffle_buffer = 1000
29 
30 dataset = tf.data.TFRecordDataset(train_files)
31 dataset = dataset.map(parser)
32 # lambda中的参数image、label, 返回的是一个元组(image, label)
33 dataset = dataset.map(lambda image, label: (preprocessed_for_train(image, image_size, image_size, None), label))
34 dataset = dataset.shuffle(shuffle_buffer).batch(batch_size)
35 # 重复NUM_EPOCHS个epoch。在7.3.4小节中TRAINING_ROUNDS指定了训练轮数,而这里指定了整个数据集重复的次数,这也间接确定了训练的轮数
36 NUM_EPOCHS = 10
37 dataset = dataset.repeat(NUM_EPOCHS)
38 
39 # 虽然定义数据集时没有直接使用placeholder来提供文件地址,但是tf.train.match_filenames_once方法得到的结果与placeholder的机制类似,也需要初始化
40 iterator = dataset.make_initializable_iterator()
41 image_batch, label_batch = iterator.get_next()
42 print(image_batch.shape, label_batch.shape)
43 
44 
45 test_dataset = tf.data.TFRecordDataset(test_files)
46 # 对于测试集,不需要预处理、shuffle、repeat操作,只需用相同的parser进行解析、调整输入层大小、batch即可
47 test_dataset = test_dataset.map(parser)
48 test_dataset = test_dataset.map(lambda image, label: (tf.image.resize_images(image, [image_size, image_size]), label))
49 test_dataset = test_dataset.batch(batch_size)
50 
51 test_iterator = test_dataset.make_initializable_iterator()
52 test_image_batch, test_label_batch = test_iterator.get_next()
53 print(test_image_batch.shape, test_label_batch.shape)
54 
55 with tf.Session() as sess:
56     tf.local_variables_initializer().run()
57     print(test_files.eval())

ok!

posted @ 2018-11-22 17:42  羊小羚  阅读(984)  评论(0编辑  收藏  举报