MXNet中bucket机制注记
Preface
之前看API以为bucket
是一个根植于底层操作的接口(MXNet doc
功不可没 -_-|| )。从LSTM
看过来,接触到了一些相关的程序,后面再把bucketing_module.py
那部分查看了下,发现bucket只是一个应用层机制,主要的实现存在于module/bucketing_module.py里面。原理清晰,实现简洁,在这做个记号。
Code & Comments
先放些相关的链接,做个预备。
- MXNet 官方的文档(\tucao 出个文档真不容易,还带时效性...)
- 大神的blog阐述,鞭辟入里
- 之前关于LSTM的blog
鉴于大神已经在这篇[blog]里面说得生动透彻了,这里就能省就省,然后说些大神没功夫顾及的细节。
另外考虑到MXNet的链接经常表现出不靠谱的症状(\kuxia),归结一下1
中有些用的结论:要使用bucket机制,初始化Module时传入的symbol应该是一个函数,这个函数在被调用时将被传入迭代器中的bucket_key参数
。
从调用路径的顺序来走一遍把。
在fit
里面经过bind
,init
等操作,后面会调用prepare
对预取出的数据(如果有)进行准备:
# module/bucketing_module.py
def prepare(self, data_batch):
"""Prepares a data batch for forward.
Parameters
----------
data_batch : DataBatch
"""
# perform bind if haven't done so
assert self.binded and self.params_initialized
bucket_key = data_batch.bucket_key
original_bucket_key = self._curr_bucket_key
data_shapes = data_batch.provide_data
label_shapes = data_batch.provide_label
self.switch_bucket(bucket_key, data_shapes, label_shapes)
# switch back
self.switch_bucket(original_bucket_key, None, None)
显然,switch_bucket
就是负责进行重新绑定的:
# module/bucketing_module.py
def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
assert self.binded, 'call bind before switching bucket'
if not bucket_key in self._buckets: # check if there is already...
symbol, data_names, label_names = self._sym_gen(bucket_key)
module = Module(symbol, data_names, label_names,
logger=self.logger, context=self._context,
work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
state_names=self._state_names)
module.bind(data_shapes, label_shapes, self._curr_module.for_training,
self._curr_module.inputs_need_grad,
force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
self._buckets[bucket_key] = module
self._curr_module = self._buckets[bucket_key]
self._curr_bucket_key = bucket_key
逻辑很明白,_curr_module
里面放了众多的module,这些module的参数全都指向同一组。如果出入的bucket_key
没有出现过,就bind一个并放入_curr_module列表里面去;如果已经有了(包括刚刚bind出来的),就切换到那个module上。
Misc
其他有一些相关的材料顺带放在这。
- 上一篇blog里面推测bucket机制可能会对补齐的那部分进行处理,这一点与
io.py
里面的DataBatch
中pad
变量有些联系。在module/base_module.py中,查找pad的引用,发现和io.py里面的注释一致,只在prediction的时候进行了使用,训练的时候被忽视。 exmple/rnn/bucketing
里面有更高层接口的使用示例。