TensorFlow dataset.shuffle、batch、repeat的使用详解
https://www.jb51.net/article/178976.htm
直接看代码例子,有详细注释!!
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
|
import tensorflow as tf import numpy as np d = np.arange( 0 , 60 ).reshape([ 6 , 10 ]) # 将array转化为tensor data = tf.data.Dataset.from_tensor_slices(d) # 从data数据集中按顺序抽取buffer_size个样本放在buffer中,然后打乱buffer中的样本 # buffer中样本个数不足buffer_size,继续从data数据集中安顺序填充至buffer_size, # 此时会再次打乱 data = data.shuffle(buffer_size = 3 ) # 每次从buffer中抽取4个样本 data = data.batch( 4 ) # 将data数据集重复,其实就是2个epoch数据集 data = data.repeat( 2 ) # 构造获取数据的迭代器 iters = data.make_one_shot_iterator() # 每次从迭代器中获取一批数据 batch = iters.get_next() sess = tf.Session() sess.run(batch) # 数据集完成遍历完之后,继续抽取的话会报错:OutOfRangeError |
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
In [ 21 ]: d Out[ 21 ]: array([[ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ], [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ], [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ], [ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ], [ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ]]) In [ 22 ]: sess.run(batch) Out[ 22 ]: array([[ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ], [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ], [ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ]]) In [ 23 ]: sess.run(batch) Out[ 23 ]: array([[ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ], [ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ]]) |
从输出结果可以看出:
shuffle是按顺序将数据放入buffer里面的;
当repeat函数在shuffle之后的话,是将一个epoch的数据集抽取完毕,再进行下一个epoch的。
那么,当repeat函数在shuffle之前会怎么样呢?如下:
1
2
3
4
5
|
data = data.repeat( 2 ) data = data.shuffle(buffer_size = 3 ) data = data.batch( 4 ) |
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
In [ 25 ]: sess.run(batch) Out[ 25 ]: array([[ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ], [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ], [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ]]) In [ 26 ]: sess.run(batch) Out[ 26 ]: array([[ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ], [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ], [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ]]) In [ 27 ]: sess.run(batch) Out[ 27 ]: array([[ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ], [ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ], [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ], [ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ]]) |
可以看出,其实它就是先将数据集复制一遍,然后把两个epoch当成同一个新的数据集,一直shuffle和batch下去。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
2019-10-24 Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 AVX512F FMA
2019-10-24 ImportError: libcublas.so.9.0: cannot open shared object file: No such file or directory
2019-10-24 cuda和tensorflow对应关系