通过迭代器获取数据
%pylab inline
from keras.datasets import mnist
import mxnet as mx
from mxnet import nd
from mxnet import autograd
import random
from mxnet import gluon
# Mini-batch size used by the data iterators in this notebook.
batch_size = 64
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Number of training samples, and the number of values in one flattened
# sample (product of the two per-sample axes; 28 * 28 per the shapes printed below).
num_examples = x_train.shape[0]
num_inputs = x_train.shape[2] * x_train.shape[1]
1. 自定义数据迭代器
def data_iter1(X, Y, batch_size):
    """Yield shuffled mini-batches ``(X_batch, Y_batch)`` as MXNet NDArrays.

    Each call draws a fresh random permutation of the sample indices and visits
    every sample exactly once. The final batch is smaller when ``batch_size``
    does not evenly divide the number of samples.

    Args:
        X: Array-like with a ``.shape`` attribute; axis 0 indexes samples.
        Y: Array-like of labels; must have the same length as ``X`` along axis 0.
        batch_size: Maximum number of samples per yielded batch.

    Yields:
        ``(nd.take(X, j), nd.take(Y, j))`` for each batch of shuffled indices ``j``.

    Raises:
        ValueError: If ``X`` and ``Y`` contain a different number of samples.
    """
    num_samples = X.shape[0]
    # Out-of-range indices would not raise inside nd.take (MXNet clips by
    # default), so a length mismatch would silently repeat labels; fail fast.
    if len(Y) != num_samples:
        raise ValueError(
            f"X has {num_samples} samples but Y has {len(Y)}"
        )
    idx = list(range(num_samples))
    random.shuffle(idx)
    X = nd.array(X)
    Y = nd.array(Y)
    # Bound the loop by this call's own sample count, not the global
    # ``num_examples``, so any split (e.g. the test set) is batched correctly.
    for i in range(0, num_samples, batch_size):
        # List slicing clamps at the end, so the last batch needs no min().
        j = nd.array(idx[i:i + batch_size])
        yield nd.take(X, j), nd.take(Y, j)
2. Gluon 迭代器
# Pair the training samples with their labels so Gluon can index them together,
# then serve them in mini-batches in shuffled order.
dataset = gluon.data.ArrayDataset(x_train, y_train)
data_iter = gluon.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
3. 从迭代器中获取数据
for data, label in data_iter:
    # One batch is enough to check the loader's output shapes.
    print(data.shape, label.shape)
    break
(64, 28, 28) (64,)
for data, label in data_iter1(x_train, y_train, batch_size=batch_size):
    # Inspect only the first batch from the hand-written iterator.
    print(data.shape, label.shape)
    break
(64, 28, 28) (64,)
更多精彩见：使用迭代器获取 CIFAR 等常用数据集
探寻有趣之事！