Fixing (Freezing) Weights in MXNet

Introduction

Recently I needed to freeze the weight of an operator (e.g. convolution_weight). The main difficulty is staying compatible with Symbol and Module (in particular, without pushing extra burden onto the data iterator) while keeping the interface simple.
For example, the weight in the following program does get updated:

import mxnet as mx
M,N=3,3
num_filter=1
kernel=mx.nd.array([ [1,2,3],[1,2,3],[1,2,3] ])


d=mx.sym.Variable('data')
conv1_w=mx.sym.Variable('kw')
conv1=mx.sym.Convolution(data=d,weight=conv1_w,kernel=(3,3),num_filter=num_filter,no_bias=True)
loss=mx.sym.MakeLoss(data=conv1)

mod=mx.mod.Module(symbol=loss,data_names=('data',))
bch_kernel=kernel.reshape((1,1,M,N))
mod.bind(data_shapes=[ ('data',[1,1,M,N]), ])
mod.set_params({'kw':bch_kernel},{})
bch_kernel=kernel.reshape((1,1,M,N))
mod.init_params()                                  # the params set above stay valid even after init_params => compatible with Module.fit?
mod.init_optimizer()                               # let's see
mod.forward(mx.io.DataBatch([bch_kernel],[]))
mod.get_outputs()[0].asnumpy()                     # B1
# get array([[[[ 42.]]]], dtype=float32)
mod.backward()
mod.update()

mod.forward(mx.io.DataBatch([bch_kernel],[]))
mod.get_outputs()[0].asnumpy()                     # B2
# get array([[[[ 41.57999802]]]], dtype=float32)
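
Where those two numbers come from (a quick back-of-the-envelope check of my own, assuming the default SGD settings used by init_optimizer, i.e. learning_rate=0.01):

import numpy as np
k = np.array([[1,2,3],[1,2,3],[1,2,3]], dtype=np.float32)
print((k*k).sum())                                 # B1: data == kernel, so the output is 3*(1+4+9) = 42
# MakeLoss makes d(loss)/d(output) = 1, so d(loss)/d(weight) = data = kernel;
# one SGD step therefore scales the weight by (1 - 0.01)
print((0.99*k*k).sum())                            # B2: 0.99 * 42 = 41.58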

The questions I found on the official site don't seem to address this need (the answers suggest operations such as Slice, or an extra multiply... I haven't yet figured out how to make those compatible with Module).

Followup

Stage 1

In the low-level implementation of the Convolution op, data, weight and bias are all passed in together as in_data, so there is no hope of solving this at the symbol level.
My first thought was to tamper with the update step:

# python/mxnet/module/module.py update() 
#  --->
# python/build/lib.linux-x86_64-2.7/mxnet/model.py

def _update_params(param_arrays, grad_arrays, updater, num_device,                                                             
                   kvstore=None):
    """ Perform update of param_arrays from grad_arrays not on kvstore."""
    for index, pair in enumerate(zip(param_arrays, grad_arrays)):
        arg_list, grad_list = pair
        if grad_list[0] is None:
            continue
        if kvstore:
            # push gradient, priority is negative index
            kvstore.push(index, grad_list, priority=-index)
            # pull back the sum gradients, to the same locations.
            kvstore.pull(index, grad_list, priority=-index)
        for k, p in enumerate(zip(arg_list, grad_list)):
            # faked an index here, to make optimizer create diff
            # state for the same index but on diff devs, TODO(mli)
            # use a better solution latter
            w, g = p 
            updater(index*num_device+k, g, w)

It looks like simply zeroing out grad_list would be enough; a rough user-side sketch of that idea is shown below. But it could get messy (besides maintaining a name-to-index lookup, the low-level support would also have to be checked), so with no better idea I went to look at where the space is allocated (the address allocation step lives in bind); that excerpt follows the sketch.
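
A minimal sketch of the zero-the-gradient idea, purely as an illustration (not the approach finally used): it reuses mod, bch_kernel and the parameter name 'kw' from the example above, and assumes the bound executors are reachable through mod._exec_group.execs and expose grad_dict (Module internals, not a public API).

mod.forward(mx.io.DataBatch([bch_kernel],[]))
mod.backward()
for exe in mod._exec_group.execs:                  # one executor per device
    exe.grad_dict['kw'][:] = 0                     # kill the gradient of the fixed weight
mod.update()                                       # the updater now sees a zero gradient for 'kw'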

# python/mxnet/module/module.py
...
    def bind(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False, shared_module=None,
             grad_req='write'):
        ...
        self._exec_group = DataParallelExecutorGroup(self._symbol, self._context,                                              
                                                     self._work_load_list, self._data_shapes,
                                                     self._label_shapes, self._param_names,
                                                     for_training, inputs_need_grad,
                                                     shared_group, logger=self.logger,
                                                     fixed_param_names=self._fixed_param_names,
                                                     grad_req=grad_req, input_types=input_types)
        ...

Then I suddenly spotted fixed_param_names, which is a promising sign.

# python/build/lib.linux-x86_64-2.7/mxnet/module/executor_group.py
...
class DataParallelExecutorGroup(object):
""
...
fixed_param_names: list of str
        Indicate parameters to be fixed during training. Parameters in this list will not allocate
        space for gradient, nor do gradient calculation.
"""

Stage 2

The docstring states it very clearly, nice. While we're at it, let's see how it is implemented:

# python/build/lib.linux-x86_64-2.7/mxnet/module/executor_group.py
...
class DataParallelExecutorGroup(object):
    def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
                 for_training, inputs_need_grad, shared_group=None, input_types=None,
                 logger=logging, fixed_param_names=None, grad_req='write'):
        ...
        for k in self.arg_names:
            if k in self.param_names:
                self.grad_req[k] = 'null' if k in self.fixed_param_names else grad_req
            elif k in data_names:
                self.grad_req[k] = grad_req if self.inputs_need_grad else 'null'
            else:
                self.grad_req[k] = 'null'
...

Judging from this fragment and the naming in the code that follows, grad_req is what controls whether gradient space gets allocated.
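
The same switch can be seen outside Module at the raw Symbol level (my sketch; grad_req is passed as a per-argument dict to simple_bind, reusing loss, M and N from the first example):

exe = loss.simple_bind(mx.cpu(), data=(1,1,M,N),
                       grad_req={'data':'write', 'kw':'null'})
print(exe.grad_arrays)                             # expect a gradient buffer for 'data' but None for 'kw'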

Solution

So the program becomes:

import mxnet as mx
M,N=3,3
num_filter=1
kernel=mx.nd.array([ [1,2,3],[1,2,3],[1,2,3] ])


d=mx.sym.Variable('data')
conv1_w=mx.sym.Variable('kw')
conv1=mx.sym.Convolution(data=d,weight=conv1_w,kernel=(3,3),num_filter=num_filter,no_bias=True)
loss=mx.sym.MakeLoss(data=conv1)

mod=mx.mod.Module(symbol=loss,data_names=('data',),fixed_param_names=('kw',))   # note: ('kw',) is a one-element tuple
bch_kernel=kernel.reshape((1,1,M,N))
mod.bind(data_shapes=[ ('data',[1,1,M,N]), ])
mod.set_params({'kw':bch_kernel},{})
bch_kernel=kernel.reshape((1,1,M,N))
mod.init_params()       
mod.init_optimizer()     
     
mod.forward(mx.io.DataBatch([bch_kernel],[]))
mod.get_outputs()[0].asnumpy()                     # B1
# array([[[[ 42.]]]], dtype=float32)
mod.backward()
mod.update()

mod.forward(mx.io.DataBatch([bch_kernel],[]))
mod.get_outputs()[0].asnumpy()                     # B2
# array([[[[ 42.]]]], dtype=float32)

good 😃
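
As an extra check (my addition, not part of the original run): after a few more update steps the stored kernel should still be untouched, which get_params confirms.

for _ in range(5):
    mod.forward(mx.io.DataBatch([bch_kernel],[]))
    mod.backward()
    mod.update()
arg_params, _ = mod.get_params()
print(arg_params['kw'].asnumpy())                  # still the original kernel values, reshaped to (1,1,3,3)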

posted @ 2017-04-28 09:05  rotxin