MXNet官网案例分析--Train MLP on MNIST
本文是MXNet的官网案例: Train MLP on MNIST. MXNet所有的模块如下图所示:
第一步: 准备数据
从下面程序可以看出,MXNet里面的数据是一个4维NDArray.
import mxnet as mx # mxnet.io.MXDataIter, shape=(128,1,28,28) train = mx.io.MNISTIter( image = '/home/zhaopace/MXNet/mxnet/example/adversary/data/train-images-idx3-ubyte', label = '/home/zhaopace/MXNet/mxnet/example/adversary/data/train-labels-idx1-ubyte', batch_size = 128, data_shape = (784, ) ) # mxnet.io.MXDataIter, shape=(128,1,28,28) val = mx.io.MNISTIter( image = '/home/zhaopace/MXNet/mxnet/example/adversary/data/t10k-images-idx3-ubyte', label = '/home/zhaopace/MXNet/mxnet/example/adversary/data/t10k-labels-idx1-ubyte', batch_size = 128, data_shape = (784, ) )
Second: 符号式编程, 生成一个两层的MLP
# Declare a two-layer MLP data = mx.symbol.Variable('data') # data layer fc1 = mx.symbol.FullyConnected(data=data, num_hidden=128) # full connected layer 1 act1 = mx.symbol.Activation(data=fc1, act_type="relu") # activation layer(relu activation function) fc2 = mx.symbol.FullyConnected(data=act1, num_hidden=64) act2 = mx.symbol.Activation(data=fc2, act_type="relu") fc3 = mx.symbol.FullyConnected(data=act2, num_hidden=10) mlp = mx.symbol.SoftmaxOutput(data=fc3, name="softmax") # Softmax layer
一个CNN网络最基本的几层:
输入层: mx.symbol.Variable()
激活层: mx.symbol.Activation()
Batch正则化: mx.symbol.BatchNorm()
Dropout: mx.symbol.Dropout()
全连接层: mx.symbol.FullyConnected()
池化层: mx.symbol.Pooling()
卷积层: mx.symbol.Convolution()
Softmax输出: mx.symbol.SoftmaxOutput()
LRN: mx.symbol.LRN()
......
mx.symbol.FullyConnected(*args, **kwargs)
功能: 对input作矩阵乘法, 并且加上一个偏置. 将shape为(batch_size, input_dim)的input变成(batch_size, num_hidden)的输出;
输入参数:
- data: Symbol类型, 输入数据;
- weight: Symbol类型, 权重矩阵;
- bias: Symbol类型, 偏置参数;
- num_hidden: int型, 必要参数, 隐层节点的数目;
- no_bias: 布尔型, 可选参数, defalut=False, 表示是否不要偏置参数
- name: 字符串类型, 可选参数, 计算结果symbol的名称;
输出参数:
- 输出是一个Symbol: the result symbol
Last: 训练以及测试
# Type: mxnet.model.FeedForward # Train a model on the data model = mx.model.FeedForward( symbol = mlp, num_epoch = 20, learning_rate = .1 ) model.fit(X = train, eval_data = val) # Predict model.predict(X = train)
class mxnet.model.FeedForward(sklearn.base.BaseEstimator)
输入参数:
- symbol: Symbol类型, 网络的symbol结构配置;
- ctx:
- num_epoch: int型, 可选参数,是一个训练参数, 训练的迭代次数;
- epoch_size: 一次epoch使用的batches数目, 默认情况下为(num_train_examples / batch_size)
- optimizer:q
- initializer:
- numpy_batch_size:
- ......
图2 mxnet.model函数列表