tensorflow TFRecordDataset shuffle 实测

代码

def data_iterator(tfrecords, batch_size=2, shuffle=True, train=True, num_parallel_reads=3):
    # 声明TFRecordDataset
    dataset = tf.data.TFRecordDataset(tfrecords, num_parallel_reads=num_parallel_reads)
    dataset = dataset.map(_parse_function)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    if train:
        dataset = dataset.repeat()

    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    return iterator

说明:

  1. dataset.shuffle(buffer_size)会在batch之间打乱,具体见前面的笔记
  2. num_parallel_reads参数可以并行加载数据,实测可以在batch内部打乱数据。

如果数据制作的时候顺序固定,相似较大,比如按顺序crop的数据得到多个tfrecords,把这两项加上可以较为充分的打乱数据

posted @ 2020-11-13 14:36  wioponsen  阅读(656)  评论(0编辑  收藏  举报