tensorflow(十七):数据的加载:map()、shuffle()、tf.data.Dataset.from_tensor_slices()

一、数据集简介

 

 

 

 

二、MNIST数据集介绍

 

 三、CIFAR 10/100数据集介绍

 

 

 

 四、tf.data.Dataset.from_tensor_slices()

 

 五、shuffle()随机打散

 

 六、map()数据预处理

 

 

 

 

 

 

 七、实战

复制代码
import tensorflow as tf
import tensorflow.keras as keras
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

def prepare_mnist_features_and_labels(x,y):
    x = tf.cast(x, tf.float32) / 255.0
    y = tf.cast(y, tf.int64)
    return x,y

def mnist_dataset():
    (x,y), (x_test,y_test) = keras.datasets.fashion_mnist.load_data() #numpy中的格式

    y = tf.one_hot(y, depth=10)                     #[10k] ==> [10k,10]的tensor
    y_test = tf.one_hot(y_test, depth=10)

    ds = tf.data.Dataset.from_tensor_slices((x,y))
    ds = ds.map(prepare_mnist_features_and_labels)  #数据预处理,注意:tf.map中传进的参数
    ds = ds.shuffle(60000).batch(100)               #随机打散,读取一个batch的样本

    ds_val = tf.data.Dataset.from_tensor_slices((x_test,y_test))
    ds_val = ds_val.map(prepare_mnist_features_and_labels)
    ds_val = ds_val.shuffle(10000).batch(100)
    return ds, ds_val


def main():
    ds, ds_val = mnist_dataset()

    print("训练集信息如下:")
    iteration_ds = iter(ds)
    iter_ds = next(iteration_ds)
    print(iter_ds[0].shape, iter_ds[1].shape)

    print("测试集信息如下:")
    iteration_ds_val = iter(ds_val)
    iter_ds_val = next(iteration_ds_val)
    print(iter_ds_val[0].shape, iter_ds_val[1].shape)

if __name__ == '__main__':
    main()
复制代码

 

 

posted @   jasonzhangxianrong  阅读(1310)  评论(0编辑  收藏  举报
编辑推荐:
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
点击右上角即可分享
微信分享提示