Tensorflow学习笔记No.3
使用tf.data加载数据
tf.data是tensorflow2.0中加入的数据加载模块,是一个非常便捷的处理数据的模块。
这里简单介绍一些tf.data的使用方法。
1.加载tensorflow中自带的mnist数据并对数据进行一些简单的处理
1 (train_image, train_label), (test_image, test_label) = tf.keras.datasets.mnist.load_data() 2 train_image = train_image / 255 3 test_image = test_image / 255
2.使用tf.data.Dataset.from_tensor_slices()方法对数据进行切片处理
该函数是dataset核心函数之一,它的作用是把给定的元组、列表和张量等数据进行特征切片。切片的范围是从最外层维度开始的。如果有多个特征进行组合,那么一次切片是把每个组合的最外维度的数据切开,分成一组一组的。
1 ds_train_label = tf.data.Dataset.from_tensor_slices(train_label) 2 ds_train_label = tf.data.Dataset.from_tensor_slices(train_label)
3.使用tf.data.Dataset.zip()方法将image和label数据合并
tf.data.Dataset.zip()方法可将迭代对象中相对应(例如image对应label)的数据打包成一个元组,返回由这些元组组成的对象。
1 ds_train = tf.data.Dataset.zip((ds_train_image, ds_train_label))
这里ds_train中的数据就是由许多个(image, label)元组组成的。
事实上我们也可以直接把train_image与train_label进行合并,以元组的形式对train_image和train_label进行切片即可。
1 ds_trian = tf.data.Dataset.from_tensor_slices((train_image, train_label))
4.使用.shuffle().repeat().batch()方法对数据进行处理
1 ds_train = ds_train.shuffle(10000).repeat(count = 3).batch(64)
.shuffle()作用是将数据进行打乱操作,传入参数为buffer_size,改参数为设置“打乱缓存区大小”,也就是说程序会维持一个buffer_size大小的缓存,每次都会随机在这个缓存区抽取一定数量的数据。
.repeat()作用就是将数据重复使用多少次,参数是重复的次数,若无参数则无限重复。
.batch()作用是将数据打包成batch_size, 每batch_size个数据打包在一起作为一个epoch。
5.注意事项
在使用tf.data时,如果不设置数据的.repeat()的重复次数,数据会无限制重复,如果把这样的数据直接输入到神经网络中会导致内存不足程序无法终止等错误。此时,要在.fit()方法中加以限制。
1 history = model.fit(ds_train, epochs = 5, steps_per_epoch = step_per_epochs, 2 validation_data = ds_test, validation_steps = 10000 // 64 3 )
使用steps_per_epoch参数限制每个epochs的数据量。
使用validation_steps限制验证集中的数据量。
到这里tf.data的简单介绍就结束了,后续会更新tf.data中的更多内容。