tensorflow学习笔记--dataset使用,创建自己的数据集

数据读入需求

我们在训练模型参数时想要从训练数据集中一次取出一小批数据(比如50条、100条)做梯度下降,不断地分批取出数据直到损失函数基本不再减小并且在训练集上的正确率足够高,取出的n条数据还要是预处理过的,一次取出的要包含输入数据和对应的lable,并且希望在达到训练效果之前可以不断地取出数据而不会因数据集取空了提前结束训练,最好取出的数据还是乱序的。

基于上面的要求,我们可以利用TensorFlow的dataset模块创建我们所需的数据集。

Dataset简介

TensorFlow程序数据导入的方法有多种。一是通过 feed_dict 传入具体值。二是利用tf的Queues创建数据队列,一次取出batch个数据进行训练,队列可以用多线程读数据,速度比较快,但是队列模块的用法比较复杂,要修改程序的时候就感觉很乱。

Dataset与队列相比就简单多了,Dataset(数据集) API 在 TensorFlow 1.4版本中已经从tf.contrib.data迁移到了tf.data之中,增加了对于Python的生成器的支持,官方强烈建议使用Dataset API 为 TensorFlow模型创建输入管道。

dataset用法

import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))

创建了一个dataset,这个dataset中含有5个元素1….,5,为了将5个元素取出,方法是从Dataset中示例化一个iterator,然后对iterator进行迭代。

iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    for i in range(5):
        print(sess.run(one_element))    

语句iterator = dataset.make_one_shot_iterator()从dataset中实例化了一个Iterator,这个Iterator是一个“one shot iterator”,即只能从头到尾读取一次。one_element = iterator.get_next()表示从iterator里取出一个元素。这里取5次后dataset里的元素就空了,再取的话就就会抛出tf.errors.OutOfRangeError异常。

除了one-hot iterator,tf还支持其他三种iterator

  • initializable
  • reinitializable
  • feedable

这三个迭代器比one-hot复杂,这里就不介绍他们了。

 

dataset元素变换

dataset数据集API还有一些操作元素的函数来满足我们的对输入数据的需求。

  • map
  • shuffle
  • batch
  • repeat

1. map

map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1:

def add1(x):
    return x+1

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) dataset = dataset.map(add1)

2. shuffle

shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小:

dataset = dataset.shuffle(500)

3. batch

使用一次iterator返回一批数据的数量:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
dataset = dataset.batch(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess: for i in range(10): print(sess.run(one_element)) # 这样就一次获取两个数,可以取3次,第三次取到一个数

4. repeat

上面的代码取3次数就取完了,再取得话就会抛出异常,如果想重复取数,可以用dataset.repeat(count),count的值表示将全部的数在dataset中重复几次:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
dataset = dataset.batch(2).repeat(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.next()
with tf.Session() as sess:
    for i in range(10):
        print(sess.run(one_element))

这样就将5个数重复了两遍。这里需要注意的一点是它虽然重复了两次,但并不是可以取5次,一次取两个数,而是:[1,2], [3,4] , [5],  [1,2], [3,4] , [5] 。这样再取到数据集末尾的时候得到的数据数量不是我们设置的batch_size 条数据。要想重复取数并且每次得到的都是batch_size条数据,可以设置batch_size的大小能被总数据量整除。

repeat()中的参数如果是None,则可以无限取数。

 

读入图片和lable,创建自己的数据集

import tensorflow as tf
import os

batch_size = 50
img_resize = [100,100]
epoch_num = None   # dataset.repeat() 的参数,设置为None,可以不断取数

# 传入图片名,返回正则化后的图片的像素值
def read_img(img_name, lable): image = tf.read_file(img_name) image = tf.image.decode_jpeg(image) image = tf.image.resize_images(image, img_resize) image = tf.image.per_image_standardization(image) return image,lable
# 传入图片所在的文件夹,图片名含有图片的lable,返回利用文件夹中图片创建的dataset
def create_dataset(path): files = os.listdir(path) # 列出文件夹中所有的图片 img_names = [] lables = [] for f in files: img_names.append(os.path.join(path,f)) # 图片的完整路径append到文件名list中 lable = f.split('.')[0] lables.append([int(i) for i in lable]) # 根据规则得到图片的lable img_names = tf.convert_to_tensor(img_names, dtype=tf.string) lables = tf.convert_to_tensor(lables, dtype=tf.float32) # 将图片名list和lable的list转换成Tensor类型
dataset
= tf.data.Dataset.from_tensor_slices((img_names,lables)) # 创建dataset,传入的需要是tensor类型
dataset
= dataset.map(read_img) # 传入read_img函数,将图片名转为像素
  
  # 将dataset打乱,设置一次获取batch_size条数据 dataset
= dataset.shuffle(buffer_size=800).batch(batch_size).repeat(epoch_num)
return dataset
dataset
= create_dataset('./img') # 图片所在的路径为./img iterator = dataset.make_one_shot_iterator() one_element = iterator.get_next() # 创建dataset是batch_size 为多少这里一次就能获取多少个数据

在程序中,sess.run(one_element) 一次就能获取到batch_size条数据和对应的lable

 

参考链接

https://blog.csdn.net/ssmixi/article/details/80572813

https://www.jianshu.com/p/d80ea5d73446

posted @ 2020-02-25 16:49  panday  阅读(6904)  评论(0编辑  收藏  举报