1. 程式人生 > >mxnet 解析模型的引數

mxnet 解析模型的引數

這段時間的工作一直在圍繞移動端展開。專案需要在手機上跑一個深度學習模型,所以直接上resnet或者是其他的比較好的重量級的模型是不現實的,甚至mobilenet都是不太理想的,我用的一個千元機,跑mobilenet幾乎需要1秒的時間。所以網路壓縮是一個必然的選擇。而進行該部分的工作,對模型引數的解析是個前提工作,下面來看如何讀取模型的引數,並對其修改。
對於mxnet來說還是比較方便的:
我們以mobilenet舉例

import mxnet as mx
sym, arg_params, aux_params = mx.model.load_checkpoint('mobilenet_v2'
, 0) for param in arg_params: print(param) ###列印conv1的卷積權重,當然可以按照需求對引數進行修改 print(arg_params['conv1_weight'])

上述程式碼會打印出讀取的模型引數的名稱, arg_params是一個字典結構 {‘引數名’:引數}
於是便可以輕鬆的對網路的引數做一些工作。比如模型壓縮。