『MXNet』第四弹_Gluon自定义层
一、不含参数层
通过继承Block自定义一个将输入减掉均值的层:CenteredLayer类,并将层的计算放在forward函数里:
from mxnet import nd, gluon
from mxnet.gluon import nn


class CenteredLayer(nn.Block):
    """Parameter-free layer that subtracts ``x.mean()`` from its input."""

    def __init__(self, **kwargs):
        # No state of its own; just forward the keyword arguments to nn.Block.
        super(CenteredLayer, self).__init__(**kwargs)

    def forward(self, x):
        """Return ``x`` shifted by its own mean."""
        centre = x.mean()
        return x - centre


# Use the layer directly.
layer = CenteredLayer()
# Example call: layer(nd.array([1, 2, 3, 4, 5]))

# Build a more complex model: two dense layers followed by the custom layer.
net = nn.Sequential()
for stage in (nn.Dense(128), nn.Dense(10), CenteredLayer()):
    net.add(stage)

# Initialize the parameters and run a forward pass on random input.
net.initialize()
y = net(nd.random.uniform(shape=(4, 8)))
二、含参数层
注意,本节实现的自定义层不能自动推断输入尺寸,需要手动指定
见上节《『MXNet』第三弹_Gluon模型参数》。在自定义层的时候,我们常使用Block自带的ParameterDict类型成员变量params来添加参数,如下:
from mxnet import gluon, nd
from mxnet.gluon import nn


class MyDense(nn.Block):
    """Dense layer with ReLU activation whose input size is given explicitly.

    The layer cannot infer its input size, so ``in_units`` must be supplied
    when it is constructed.

    Args:
        units: Number of output units.
        in_units: Number of input units.
        **kwargs: Forwarded to ``nn.Block``.
    """

    def __init__(self, units, in_units, **kwargs):
        super(MyDense, self).__init__(**kwargs)
        # Creating the parameters through self.params (the block's
        # ParameterDict) registers them with the block.
        self.weight = self.params.get('weight', shape=(in_units, units))
        self.bias = self.params.get('bias', shape=(units,))

    def forward(self, x):
        """Return ``relu(dot(x, weight) + bias)``."""
        # `nd` must be imported for this method to run; the original snippet
        # only imported `gluon` and `nn`.
        linear = nd.dot(x, self.weight.data()) + self.bias.data()
        return nd.relu(linear)


# Actual usage: construct the layer (call dense.initialize() before the
# first forward pass).
dense = MyDense(5, in_units=10)
如果不想使用ParameterDict类,则需要以下操作:
# Replaces the ParameterDict-based form:
#     self.weight = self.params.get('weight', shape=(in_units, units))
# Create the parameter by hand, then register it in self.params explicitly.
weight = gluon.Parameter('weight', shape=(in_units, units))
self.weight = weight
self.params.update({'weight': weight})
否则在net.initialize()初始化时是初始化不到ParameterDict外变量的。
有关这一点详见下面:
def __init__(self, conv_arch, dropout_keep_prob, **kwargs):
    """Build the convolutional trunk described by ``conv_arch``.

    Args:
        conv_arch: Sequence of argument tuples, one per stage. Each tuple is
            unpacked into ``repeat`` (a helper defined elsewhere); the first
            stage is additionally built with ``pool=False``.
        dropout_keep_prob: Accepted but not used by this constructor.
        **kwargs: Forwarded to the parent block constructor.
    """
    super(SSD, self).__init__(**kwargs)
    self.vgg_conv = nn.Sequential()
    self.vgg_conv.add(repeat(*conv_arch[0], pool=False))
    # A plain loop instead of a list comprehension used only for side effects.
    for stage_args in conv_arch[1:]:
        self.vgg_conv.add(repeat(*stage_args))
    # Alternative: keep the stages in a plain Python container. An iterator
    # can only be consumed once, so it has to be materialized as a tuple;
    # otherwise, after parameter discovery iterates it, the loop in forward
    # would exit immediately:
    #     self.vgg_conv = tuple([repeat(*conv_arch[i])
    #                            for i in range(len(conv_arch))])
    # Only parameters of attributes that are directly mxnet layers or mxnet
    # sequential containers are recognized automatically. With any other
    # container the parameters must be collected into the parameter dict:
    #     for block in self.vgg_conv:
    #         self.params.update(block.collect_params())


def forward(self, x, feat_layers=None):
    """Run ``x`` through every stage and return all intermediate outputs.

    Args:
        x: Input to the first stage.
        feat_layers: Unused. It defaults to ``None`` so that the block can be
            called with the input alone, e.g. ``ssd(X)``.

    Returns:
        dict mapping ``'block0'`` to the input and ``'block<i>'`` to the
        output of the i-th stage (1-based).
    """
    end_points = {'block0': x}
    out = x
    for index, block in enumerate(self.vgg_conv, start=1):
        out = block(out)
        end_points['block{:d}'.format(index)] = out
    return end_points
属性对象是mxnet的对象时才能默认识别层中的参数,否则需要显式收集进self.params中。
测试代码:
if __name__ == '__main__':
    import pprint as pp

    # Build the network with five stages and initialize its parameters.
    ssd = SSD(
        conv_arch=((2, 64), (2, 128), (3, 256), (3, 512), (3, 512)),
        dropout_keep_prob=0.5,
    )
    ssd.initialize()

    # Feed one random single-channel 304x304 input and print the shape of
    # every value in the mapping the network returns.
    X = mx.ndarray.random.uniform(shape=(1, 1, 304, 304))
    pp.pprint([feature.shape for feature in ssd(X).values()])
自行验证即可。