tf.data.Dataset.from_tensor_slices 的用法

tf.data.Dataset.from_tensor_slices

  该函数是dataset核心函数之一,它的作用是把给定的元组、列表和张量等数据进行特征切片。切片的范围是从最外层维度开始的。如果有多个特征进行组合,那么一次切片是把每个组合的最外维度的数据切开,分成一组一组的。

  假设我们现在有两组数据,分别是特征和标签,为了简化说明问题,我们假设每两个特征对应一个标签。之后把特征和标签组合成一个tuple,那么我们的想法是让每个标签都恰好对应2个特征,而且像直接切片,比如:[f11, f12] [t1]。f11表示第一个数据的第一个特征,f12表示第1个数据的第二个特征,t1表示第一个数据标签。那么tf.data.Dataset.from_tensor_slices就是做了这件事情:

Example:

import tensorflow as tf
import numpy as np
 
features, labels = (np.random.sample((6, 3)),  # 模拟6组数据,每组数据3个特征
                    np.random.sample((6, 1)))  # 模拟6组数据,每组数据对应一个标签,注意两者的维数必须匹配
 
print((features, labels))  #  输出下组合的数据
data = tf.data.Dataset.from_tensor_slices((features, labels))
print(data)  # 输出张量的信息

 

posted @ 2022-02-26 11:39  图神经网络  阅读(640)  评论(0编辑  收藏  举报
Live2D