Understanding TensorFlow's data.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()

batch is easy to understand: it is simply the batch size. Note that within one epoch the last batch may contain fewer samples than (or exactly) batch size.
dataset.repeat corresponds to what is commonly called epochs, but in tf the order in which it is used relative to dataset.shuffle may cause samples from different epochs to be mixed together.
dataset.shuffle maintains a shuffle buffer of buffer size elements; every sample the graph needs is drawn from this shuffle buffer, and each time a sample is taken out, one new sample from the source dataset is added to the shuffle buffer.
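
The cache() call from the title is not exercised in the experiments below, so here is a minimal sketch of the full chain (my own illustration, written in the same TF 1.x style as the rest of this post; the 11x2 array and the buffer/batch sizes are just example values). cache() keeps the elements produced on the first pass in memory, so later repetitions are served from that cache instead of re-running the upstream ops. The experiments that follow drop cache() and focus on how shuffle, batch, and repeat interact.

import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.sample((11, 2))
dataset = tf.data.Dataset.from_tensor_slices(x)
# cache in memory, shuffle with a buffer covering the whole dataset, batch, then repeat twice
dataset = dataset.cache().shuffle(buffer_size=11).batch(4).repeat(2)

el = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break  # both repetitions are exhausted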

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
    # 11 samples, shuffle(3), batch(4), repeat(2): 3 batches per epoch x 2 epochs = 6 batches;
    # iterate until the dataset is exhausted instead of hard-coding the number of sess.run calls
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break
# Source dataset
[[ 0.5488135   0.71518937]
 [ 0.60276338  0.54488318]
 [ 0.4236548   0.64589411]
 [ 0.43758721  0.891773  ]
 [ 0.96366276  0.38344152]
 [ 0.79172504  0.52889492]
 [ 0.56804456  0.92559664]
 [ 0.07103606  0.0871293 ]
 [ 0.0202184   0.83261985]
 [ 0.77815675  0.87001215]
 [ 0.97861834  0.79915856]]

# Samples obtained after shuffle and batch
[[ 0.4236548   0.64589411]
 [ 0.60276338  0.54488318]
 [ 0.43758721  0.891773  ]
 [ 0.5488135   0.71518937]]
[[ 0.96366276  0.38344152]
 [ 0.56804456  0.92559664]
 [ 0.0202184   0.83261985]
 [ 0.79172504  0.52889492]]
[[ 0.07103606  0.0871293 ]
 [ 0.97861834  0.79915856]
 [ 0.77815675  0.87001215]]  # the last batch of this epoch has 3 samples
[[ 0.60276338  0.54488318]
 [ 0.5488135   0.71518937]
 [ 0.43758721  0.891773  ]
 [ 0.79172504  0.52889492]]
[[ 0.4236548   0.64589411]
 [ 0.56804456  0.92559664]
 [ 0.0202184   0.83261985]
 [ 0.07103606  0.0871293 ]]
[[ 0.77815675  0.87001215]
 [ 0.96366276  0.38344152]
 [ 0.97861834  0.79915856]] # the last batch again has 3 samples

1. According to the buffer size set in shuffle, first take three samples from the source dataset:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
2. Randomly take one sample from the buffer into the batch:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
batch:
[ 0.4236548 0.64589411]
3. The shuffle buffer now has fewer than three samples, so pull one more sample from the source dataset:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
4. Randomly take another sample from the buffer into the batch:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
batch:
[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
5. And so on. This means that if the shuffle buffer size is 1, the dataset is not shuffled at all; if the buffer size equals the number of samples in the dataset, the whole dataset is shuffled uniformly at random.
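
To make the mechanism above concrete, here is a small pure-Python simulation of a shuffle buffer (a sketch only, not TensorFlow's actual implementation; it reuses the same 11x2 array with buffer size 3 and batch size 4). The TensorFlow run below then checks the buffer size = 1 case directly.

import numpy as np

np.random.seed(0)
x = np.random.sample((11, 2))

def simulated_shuffle_batches(data, buffer_size, batch_size):
    source = list(data)                            # stand-in for the source dataset
    buffer = [source.pop(0) for _ in range(min(buffer_size, len(source)))]
    batch = []
    while buffer:
        idx = np.random.randint(len(buffer))       # draw one random element from the buffer
        batch.append(buffer.pop(idx))
        if source:                                 # refill the buffer from the source dataset
            buffer.append(source.pop(0))
        if len(batch) == batch_size:
            yield np.array(batch)
            batch = []
    if batch:                                      # the final, possibly smaller, batch
        yield np.array(batch)

for b in simulated_shuffle_batches(x, buffer_size=3, batch_size=4):
    print(b)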

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
    # with shuffle(1) the buffer holds a single element, so the original order is preserved
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break

[[ 0.5488135   0.71518937]
 [ 0.60276338  0.54488318]
 [ 0.4236548   0.64589411]
 [ 0.43758721  0.891773  ]
 [ 0.96366276  0.38344152]
 [ 0.79172504  0.52889492]
 [ 0.56804456  0.92559664]
 [ 0.07103606  0.0871293 ]
 [ 0.0202184   0.83261985]
 [ 0.77815675  0.87001215]
 [ 0.97861834  0.79915856]]

[[ 0.5488135   0.71518937]
 [ 0.60276338  0.54488318]
 [ 0.4236548   0.64589411]
 [ 0.43758721  0.891773  ]]
[[ 0.96366276  0.38344152]
 [ 0.79172504  0.52889492]
 [ 0.56804456  0.92559664]
 [ 0.07103606  0.0871293 ]]
[[ 0.0202184   0.83261985]
 [ 0.77815675  0.87001215]
 [ 0.97861834  0.79915856]]
[[ 0.5488135   0.71518937]
 [ 0.60276338  0.54488318]
 [ 0.4236548   0.64589411]
 [ 0.43758721  0.891773  ]]
[[ 0.96366276  0.38344152]
 [ 0.79172504  0.52889492]
 [ 0.56804456  0.92559664]
 [ 0.07103606  0.0871293 ]]
[[ 0.0202184   0.83261985]
 [ 0.77815675  0.87001215]
 [ 0.97861834  0.79915856]]

Note what happens if repeat is used before shuffle:
The official docs say that calling repeat before shuffle improves performance, but it blurs the epoch boundaries between samples.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)
dataset = dataset.shuffle(11)
dataset = dataset.batch(4)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
    # repeat(2) before shuffle: all 22 elements are shuffled together before batching
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break

[[ 0.5488135   0.71518937]
 [ 0.60276338  0.54488318]
 [ 0.4236548   0.64589411]
 [ 0.43758721  0.891773  ]
 [ 0.96366276  0.38344152]
 [ 0.79172504  0.52889492]
 [ 0.56804456  0.92559664]
 [ 0.07103606  0.0871293 ]
 [ 0.0202184   0.83261985]
 [ 0.77815675  0.87001215]
 [ 0.97861834  0.79915856]]

[[ 0.56804456  0.92559664]
 [ 0.5488135   0.71518937]
 [ 0.60276338  0.54488318]
 [ 0.07103606  0.0871293 ]]
[[ 0.96366276  0.38344152]
 [ 0.43758721  0.891773  ]
 [ 0.43758721  0.891773  ]
 [ 0.77815675  0.87001215]]
[[ 0.79172504  0.52889492]   # the same sample appears twice within a single batch
 [ 0.79172504  0.52889492]
 [ 0.60276338  0.54488318]
 [ 0.4236548   0.64589411]]
[[ 0.07103606  0.0871293 ]
 [ 0.4236548   0.64589411]
 [ 0.96366276  0.38344152]
 [ 0.5488135   0.71518937]]
[[ 0.97861834  0.79915856]
 [ 0.0202184   0.83261985]
 [ 0.77815675  0.87001215]
 [ 0.56804456  0.92559664]]
[[ 0.0202184   0.83261985]
 [ 0.97861834  0.79915856]]          # note the last batch has 2 samples while all earlier ones have 4

Usage example:
def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    print('Parsing', filenames)
    def decode_libsvm(line):
        #columns = tf.decode_csv(value, record_defaults=CSV_COLUMN_DEFAULTS)
        #features = dict(zip(CSV_COLUMNS, columns))
        #labels = features.pop(LABEL_COLUMN)
        columns = tf.string_split([line], ' ')
        labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
        splits = tf.string_split(columns.values[1:], ':')
        id_vals = tf.reshape(splits.values,splits.dense_shape)
        feat_ids, feat_vals = tf.split(id_vals,num_or_size_splits=2,axis=1)
        feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
        feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
        #feat_ids = tf.reshape(feat_ids,shape=[-1,FLAGS.field_size])
        #for i in range(splits.dense_shape.eval()[0]):
        #    feat_ids.append(tf.string_to_number(splits.values[2*i], out_type=tf.int32))
        #    feat_vals.append(tf.string_to_number(splits.values[2*i+1]))
        #return tf.reshape(feat_ids,shape=[-1,field_size]), tf.reshape(feat_vals,shape=[-1,field_size]), labels
        return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TextLineDataset(filenames).map(decode_libsvm, num_parallel_calls=10).prefetch(500000)    # multi-thread pre-process then prefetch

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # Call repeat after shuffle, rather than before, to prevent separate epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size) # Batch size to use

    #return dataset.make_one_shot_iterator()
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    #return tf.reshape(batch_ids,shape=[-1,field_size]), tf.reshape(batch_vals,shape=[-1,field_size]), batch_labels
    return batch_features,batch_labels
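
For context, such an input_fn is usually handed to an Estimator through a lambda. The snippet below is only a rough sketch of that wiring; my_model_fn, train_files, and eval_files are placeholder names that do not appear in the original post.

estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir="./model_dir")
# shuffle only for training; evaluation reads the files in order
estimator.train(
    input_fn=lambda: input_fn(train_files, batch_size=256, num_epochs=1, perform_shuffle=True))
estimator.evaluate(
    input_fn=lambda: input_fn(eval_files, batch_size=256, num_epochs=1, perform_shuffle=False))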

Copyright notice: this is an original article by the blogger, published under the CC 4.0 BY-SA license; please include a link to the original source and this notice when reposting.

Original link: https://blog.csdn.net/qq_16234613/article/details/81703228

Answer by Houtarou Oreki:

For example: with the program below you can see that shuffle draws each sample at random from a buffer of buffer_size elements taken from the dataset.

import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([0,1,2,3,4,5,6,7,8,9])
dataset=dataset.shuffle(buffer_size=2)
dataset = dataset.batch(batch_size=1)
iterator = dataset.make_initializable_iterator()
next_element=iterator.get_next()
init_op = iterator.initializer
with tf.Session() as sess:
    sess.run(init_op)
    for i in range(10):
        print(sess.run(next_element))

I got the following output:

[1]
[0]
[3]
[2]
[4]
[5]
[7]
[8]
[9]
[6]

The key idea behind the buffer is that it always keeps buffer_size elements in memory. Once you randomly draw a sample (batch) from the buffer, the next element is pushed into the buffer, and you sample again from this new buffer.

 buffer:0,1, get a sample  [1]
 buffer:0,2, get a sample  [0]
 buffer:2,3, get a sample  [3]
 buffer:2,4, get a sample  [2]
 buffer:4,5, get a sample  [4]
 buffer:5,6, get a sample  [5]
 buffer:6,7, get a sample  [7]
 buffer:6,8, get a sample  [8]
 buffer:6,9, get a sample  [9]
 buffer:6    get a sample  [6]

Author of this post: 薄书

Original link: https://www.cnblogs.com/aimoboshu/p/14567332.html

Copyright notice: this work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivatives 2.5 China Mainland License.
