如何用加载本地数据库为tf.data.Dataset格式
云端的数据库存储在google的服务器,所以无法通过tfds.load('mnist', split='train')这样的方式直接从云端读取,而且tfds.load('mnist',split='train',data_dir='...')也无法实现本地加载。
下面简单描述如何从本地加载数据
一、MNIST数据库
1. 通过gzip模块打开本地的train-images-idx3-ubyte.gz文件为numpy数据
2. 通过tf.data.Dataset.from_tensor_slices读取numpy数据
3.将image数据和label数据打包到一起,并分别打包的数据转换成字典
代码如下:(tf版本为2.12.0)
import gzip import numpy as np import tensorflow as tf # 解压并读取图像数据为 numpy array # 注意:根据你的idx3-ubyte文件的实际数据维度reshape,如下是28x28图像的例子 with gzip.open('train-images-idx3-ubyte.gz', 'rb') as f: images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28 * 28) # 解压并读取标签数据为 numpy array with gzip.open('train-labels-idx1-ubyte.gz', 'rb') as f: labels = np.frombuffer(f.read(), np.uint8, offset=8) # 转换 numpy array 到TF数据集对象 dataset_images = tf.data.Dataset.from_tensor_slices(images) dataset_labels = tf.data.Dataset.from_tensor_slices(labels) # 组合图像和标签数据集到一起并变成字典形式 dataset = tf.data.Dataset.zip((dataset_images, dataset_labels)) dataset = dataset.map(lambda x, y: {"image": x, "label": y}) # 现在,你可以像之前示例代码一样处理dataset数据集: for example in dataset.take(1): image, label = example['image'], example['label']