MXNet 中symbol 绑定

于看到MXNet 的doc 了。今天准备做些GAN的试验需要手工hack些步骤,遇到要绑定,想起毕业设计时在julia中闷着拼差点又要热血上涌。。。。(罪过)。还好,在example 里面发现了一些可参考的,但doc 中关于bind 的部分有些残缺,似乎只介绍了关于executor的接口:

A = mx.Variable(:A)
B = mx.Variable(:B)
C = A .* B
a = mx.ones(3) * 4
b = mx.ones(3) * 2
c_exec = mx.bind(C, context=mx.cpu(), args=Dict(:A => a, :B => b))
mx.forward(c_exec)
copy(c_exec.outputs[1]) # copy turns NDArray into Julia Array
# =>
# 3-element Array{Float32,1}:
# 8.0
# 8.0
# 8.0

example中的程序更实在些,稍微改了下,记到这里:

import mxnet as mx
import numpy as np

M,N=10,20
device=mx.cpu()

data=mx.sym.Variable('data')
label=mx.sym.Variable('label')
conv1=mx.sym.Convolution(data=data,kernel=(3,3),num_filter=2)
flatten=mx.sym.Flatten(data=conv1)
fc1=mx.sym.FullyConnected(data=flatten,num_hidden=1)
loss_data=fc1# flatten
loss=mx.sym.LogisticRegressionOutput(data=loss_data,label=label)
img=np.zeros((M,N)).reshape((1,1,M,N))
gdt=[1,]

img=mx.nd.array(img)
gdt=mx.nd.array(gdt)


mod=mx.module.Module(symbol=loss,data_names=('data',),label_names=('label',))
#D={'data':img,'label':gdt}
mod.bind(data_shapes=[('data',(1,1,M,N))],label_shapes=[('label',(1,))],inputs_need_grad=True)
mod.init_params()
mod.init_optimizer(optimizer='adam')
mod.forward(mx.io.DataBatch([img],[gdt]),is_train=True)
out=mod.get_outputs()
posted @ 2017-04-28 10:11  rotxin  阅读(811)  评论(0编辑  收藏  举报