A Guide to the TensorFlow Data API: tf.data.Dataset
Loading the data
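import numpy as np
import pandas as pd
import tensorflow as tf
X = pd.read_csv('./datasets/housing/housing.csv')
X = X.sample(n=10)  # keep 10 random rows
X.drop(columns=X.columns.difference(['longitude']), inplace=True)  # keep only the 'longitude' column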
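To avoid errors, first convert the data to a float32 NumPy array, then create the dataset and print its elements:
X = np.asarray(X).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(X)
for element in dataset:
    print(element)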
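Here is a minimal sketch of what from_tensor_slices does. It uses a small hypothetical in-memory array in place of the CSV file. The array is sliced along its first axis, so an (N, 1) array becomes N elements of shape (1,). Passing a tuple such as (features, labels) slices each component in step, which is a common way to pair inputs with targets. The labels here are made up purely for illustration.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the 'longitude' column: 4 rows x 1 column
features = np.array([[-118.75], [-119.25], [-118.18], [-118.13]], dtype=np.float32)
# Hypothetical labels, one per row (not part of the housing example)
labels = np.array([0, 1, 0, 1], dtype=np.int64)

# One element per row; each element has shape (1,)
dataset = tf.data.Dataset.from_tensor_slices(features)
element_shape = dataset.element_spec.shape.as_list()

# A tuple is sliced component-wise: each element is a (feature, label) pair
pairs = tf.data.Dataset.from_tensor_slices((features, labels))
collected = [(f.numpy(), int(l)) for f, l in pairs]

print(element_shape)   # [1]
print(len(collected))  # 4
print(collected[1])
```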
tf.Tensor([-118.75], shape=(1,), dtype=float32)
tf.Tensor([-119.25], shape=(1,), dtype=float32)
tf.Tensor([-118.18], shape=(1,), dtype=float32)
tf.Tensor([-118.13], shape=(1,), dtype=float32)
tf.Tensor([-118.2], shape=(1,), dtype=float32)
tf.Tensor([-117.25], shape=(1,), dtype=float32)
tf.Tensor([-117.93], shape=(1,), dtype=float32)
tf.Tensor([-122.96], shape=(1,), dtype=float32)
tf.Tensor([-121.77], shape=(1,), dtype=float32)
tf.Tensor([-121.24], shape=(1,), dtype=float32)
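dataset = dataset.repeat(3).batch(10)
for batch in dataset:
    print(batch)
Explanation:
repeat(3) repeats the dataset 3 times (30 elements in total), and batch(10) combines every 10 consecutive elements into one batch, so the loop prints three batches of shape (10, 1):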
tf.Tensor(
[[-118.75]
[-119.25]
[-118.18]
[-118.13]
[-118.2 ]
[-117.25]
[-117.93]
[-122.96]
[-121.77]
[-121.24]], shape=(10, 1), dtype=float32)
tf.Tensor(
[[-118.75]
[-119.25]
[-118.18]
[-118.13]
[-118.2 ]
[-117.25]
[-117.93]
[-122.96]
[-121.77]
[-121.24]], shape=(10, 1), dtype=float32)
tf.Tensor(
[[-118.75]
[-119.25]
[-118.18]
[-118.13]
[-118.2 ]
[-117.25]
[-117.93]
[-122.96]
[-121.77]
[-121.24]], shape=(10, 1), dtype=float32)
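You can also check the batch structure without reading the printed tensors by collecting each batch's shape. This sketch uses hypothetical data (the numbers 0 to 9) instead of the housing sample. No shuffle is applied, so every repetition should reproduce the rows in their original order.

```python
import numpy as np
import tensorflow as tf

# Hypothetical 10 x 1 data standing in for the longitude sample
X = np.arange(10, dtype=np.float32).reshape(10, 1)
dataset = tf.data.Dataset.from_tensor_slices(X).repeat(3).batch(10)

shapes = []
same_order = []
for batch in dataset:
    shapes.append(batch.shape.as_list())
    # Without shuffling, each repetition yields the rows in their original order
    same_order.append(np.array_equal(batch.numpy(), X))

print(shapes)      # [[10, 1], [10, 1], [10, 1]]
print(same_order)  # [True, True, True]
```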
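If the number of elements is not a multiple of the batch size, for example with batch(9) applied to the unbatched data:
dataset_9 = tf.data.Dataset.from_tensor_slices(X).repeat(3).batch(9)
for batch in dataset_9:
    print(batch)
The last batch contains the remaining elements: 30 = 3 × 9 + 3, so it holds only 3. (The output below appears to come from a different random sample that was not converted to float32, hence the different values and the float64 dtype.)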
tf.Tensor(
[[-122.08]
[-121.37]
[-118.32]
[-122.38]
[-122.09]
[-122.1 ]
[-122.27]
[-121.49]
[-120.68]], shape=(9, 1), dtype=float64)
tf.Tensor(
[[-118.2 ]
[-122.08]
[-121.37]
[-118.32]
[-122.38]
[-122.09]
[-122.1 ]
[-122.27]
[-121.49]], shape=(9, 1), dtype=float64)
tf.Tensor(
[[-120.68]
[-118.2 ]
[-122.08]
[-121.37]
[-118.32]
[-122.38]
[-122.09]
[-122.1 ]
[-122.27]], shape=(9, 1), dtype=float64)
tf.Tensor(
[[-121.49]
[-120.68]
[-118.2 ]], shape=(3, 1), dtype=float64)
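If a short final batch is unwanted, batch accepts a drop_remainder argument that discards it. The order of repeat and batch also matters. Batching before repeating puts a partial batch at the end of every pass, instead of only once at the very end. The sketch below compares the three variants on hypothetical data (the numbers 0 to 9).

```python
import numpy as np
import tensorflow as tf

# Hypothetical 10 x 1 data standing in for the longitude sample
X = np.arange(10, dtype=np.float32).reshape(10, 1)
base = tf.data.Dataset.from_tensor_slices(X)

# repeat then batch: 30 elements -> batches of 9, 9, 9, 3
sizes = [b.shape[0] for b in base.repeat(3).batch(9)]

# drop_remainder=True discards the final partial batch -> 9, 9, 9
dropped = base.repeat(3).batch(9, drop_remainder=True)
sizes_dropped = [b.shape[0] for b in dropped]

# batch then repeat: each 10-element pass ends with a 1-element batch
sizes_swapped = [b.shape[0] for b in base.batch(9).repeat(3)]

# With drop_remainder=True the batch dimension is known statically
static_dropped = dropped.element_spec.shape.as_list()                # [9, 1]
static_kept = base.repeat(3).batch(9).element_spec.shape.as_list()   # [None, 1]

print(sizes, sizes_dropped, sizes_swapped)
```

A fixed batch dimension can matter for models or hardware that require every batch to have the same size.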
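The map function
map applies a function to every element of the dataset. Here each element is a batch, and the function replaces every value with its absolute value:
dataset = dataset.map(lambda x: tf.abs(x))
for batch in dataset:
    print(batch)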
tf.Tensor(
[[118.75]
[119.25]
[118.18]
[118.13]
[118.2 ]
[117.25]
[117.93]
[122.96]
[121.77]
[121.24]], shape=(10, 1), dtype=float32)
tf.Tensor(
[[118.75]
[119.25]
[118.18]
[118.13]
[118.2 ]
[117.25]
[117.93]
[122.96]
[121.77]
[121.24]], shape=(10, 1), dtype=float32)
tf.Tensor(
[[118.75]
[119.25]
[118.18]
[118.13]
[118.2 ]
[117.25]
[117.93]
[122.96]
[121.77]
[121.24]], shape=(10, 1), dtype=float32)
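map also accepts an existing TensorFlow op directly, and it can run the function in parallel through num_parallel_calls. The sketch below assumes TensorFlow 2.4 or later, where tf.data.AUTOTUNE is available, and uses a few hypothetical longitude values. By default, parallel map still returns elements in their original order.

```python
import numpy as np
import tensorflow as tf

# Hypothetical longitude values
X = np.array([[-118.75], [-119.25], [-122.96]], dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices(X).batch(2)

# Pass tf.abs itself instead of a lambda; let tf.data choose the parallelism
dataset = dataset.map(tf.abs, num_parallel_calls=tf.data.AUTOTUNE)

# Stitch the batches back together to inspect all values at once
result = np.concatenate([batch.numpy() for batch in dataset])
print(result)
```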
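The filter function
filter keeps only the elements for which a predicate returns True, and the predicate must return a scalar tf.bool. On a (10, 1) batch, x < 120 yields a (10, 1) tensor of booleans rather than a single value, so the dataset must be unbatched first:
dataset = dataset.unbatch()
dataset = dataset.filter(lambda x: x[0] < 120)  # each element has shape (1,); x[0] makes the result a scalar
for element in dataset:
    print(element)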
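The full unbatch, filter, and rebatch round trip can be sketched on a few hypothetical absolute longitude values. After filtering, the surviving elements can be batched again if later steps expect batches. Indexing with x[0] is one way to get a scalar predicate. tf.reduce_all(x < 120) would be an equivalent choice for elements with more than one value.

```python
import numpy as np
import tensorflow as tf

# Hypothetical absolute longitude values, already batched in pairs
X = np.array([[118.75], [119.25], [122.96], [117.25], [121.77]], dtype=np.float32)
batched = tf.data.Dataset.from_tensor_slices(X).batch(2)

# unbatch -> elements of shape (1,); x[0] < 120 is a scalar tf.bool
kept = batched.unbatch().filter(lambda x: x[0] < 120)
values = [float(x.numpy()[0]) for x in kept]
print(values)  # [118.75, 119.25, 117.25]

# Re-batch the filtered elements: 3 elements -> batches of 2 and 1
rebatched_sizes = [b.shape[0] for b in kept.batch(2)]
print(rebatched_sizes)  # [2, 1]
```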