1. 程式人生 > >MXNet中bucket機制註記

MXNet中bucket機制註記

.org sse shape 沒有 sta ams origin done org

Preface

之前看API以為bucket是一個根植於底層操作的接口(MXNet doc功不可沒 -_-|| )。從LSTM看過來,接觸到了一些相關的程序,後面再把bucketing_module.py那部分查看了下,發現bucket只是一個應用層機制,主要的實現存在於module/bucketing_module.py裏面。原理清晰,實現簡潔,在這做個記號。

Code & Comments

先放些相關的鏈接,做個預備。

  1. MXNet 官方的文檔(\tucao 出個文檔真不容易,還帶時效性...)
  2. 大神的blog闡述,鞭辟入裏
  3. 之前關於LSTM的blog
    鑒於大神已經在這篇[blog]裏面說得生動透徹了,這裏就能省就省,然後說些大神沒功夫顧及的細節。
    另外考慮到MXNet的鏈接經常表現出不靠譜的癥狀(\kuxia),歸結一下1
    中有些用的結論:要使用bucket機制,初始化Module時傳入的symbol應該是一個函數,這個函數在被調用時將被傳入叠代器中的bucket_key參數

從調用路徑的順序來走一遍把。
fit裏面經過bind,init等操作,後面會調用prepare對預取出的數據(如果有)進行準備:

# module/bucketing_module.py
    def prepare(self, data_batch):
        """Prepares a data batch for forward.

        Parameters
        ----------
        data_batch : DataBatch
""" # perform bind if haven‘t done so assert self.binded and self.params_initialized bucket_key = data_batch.bucket_key original_bucket_key = self._curr_bucket_key data_shapes = data_batch.provide_data label_shapes = data_batch.provide_label self
.switch_bucket(bucket_key, data_shapes, label_shapes) # switch back self.switch_bucket(original_bucket_key, None, None)

顯然,switch_bucket就是負責進行重新綁定的:

# module/bucketing_module.py
    def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
         assert self.binded, ‘call bind before switching bucket‘
        if not bucket_key in self._buckets:    # check if there is already...
            symbol, data_names, label_names = self._sym_gen(bucket_key)
            module = Module(symbol, data_names, label_names,
                            logger=self.logger, context=self._context,
                            work_load_list=self._work_load_list,
                            fixed_param_names=self._fixed_param_names,
                            state_names=self._state_names)
            module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                        self._curr_module.inputs_need_grad,
                        force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
            self._buckets[bucket_key] = module

        self._curr_module = self._buckets[bucket_key]
        self._curr_bucket_key = bucket_key

邏輯很明白,_curr_module裏面放了眾多的module,這些module的參數全都指向同一組。如果出入的bucket_key沒有出現過,就bind一個並放入*_curr_module列表裏面去;如果已經有了(包括剛剛bind出來的),就切換到那個module*上。

Misc

其他有一些相關的材料順帶放在這。

  1. 上一篇blog裏面推測bucket機制可能會對補齊的那部分進行處理,這一點與io.py裏面的DataBatchpad變量有些聯系。在module/base_module.py中,查找pad的引用,發現和io.py裏面的註釋一致,只在prediction的時候進行了使用,訓練的時候被忽視。
  2. exmple/rnn/bucketing裏面有更高層接口的使用示例。

MXNet中bucket機制註記