Mxnet-Python API 學習——Module API
Mxnet-Python API中包含了什麼?
包含了兩個主要的高層庫:Gluon API和Module API。這兩個高層包的基礎是NDArray(命令式)和Symbol(宣告式)。當然,還有一些其他的API,如Autograd API,IO API等。 本文主要對Module API的用法做了總結: Module 中提供了基於Symbol進行計算的中層和高層的介面,模組化了training和inference過程中經常需要用到的程式碼。例如,在訓練一個神經網路時,我們需要輸入訓練資料,初始化模型引數,接著進行前向和反向傳播,更新引數,儲存和恢復模型引數等。Module則把所有這些過程都封裝在一起了。下面通過一個具體的例子來體會這個過程。
1.準備資料
import mxnet as mx import numpy as np import logging logging.getLogger().setLevel(logging.INFO) mx.random.seed(1234) fname = mx.test_utils.download('https://s3.us-east-2.amazonaws.com/mxnet-public/letter_recognition/letter-recognition.data') data = np.genfromtxt(fname, delimiter=',')[:,1:] label = np.array([ord(l.split(',')[0])-ord('A') for l in open(fname, 'r')])#ord()函式返回ASCII碼值 batch_size=32 ntrain=int(data.shape[0]*0.8) train_iter=mx.io.NDArrayIter(data[:ntrain,:],label[:ntrain],batch_size,shuffle=True) val_iter=mx.io.NDArrayIter(data[ntrain:,:],label[ntrain:],batch_size)
2.使用Symbol定義網路
net=mx.symbol.Variable('data') net=mx.symbol.FullyConnected(net,name='fc1',num_hidden=64) net=mx.symbol.Activation(net,name='relu1',act_type="relu") net=mx.symbol.FullyConnected(net,name='fc2',num_hidden=26) net=mx.symbol.SoftmaxOutput(net,name='softmax') mx.viz.plot_network(net).view()#視覺化我們建立的網路
3.建立Module
我們可以通過指定下面的引數來建立Module: (1)symbol——定義的網路(也就是我們使用symbol定義的net網路) (2)context——執行計算過程使用的裝置或者裝置列表 (3)data_names——輸入資料變數名稱列表(網路中的’data’) (4)label_names——輸入標籤變數名稱列表 (網路中的softmax_label)
mod=mx.mod.Module(symbol=net,context=mx.cpu(),data_names=['data'],label_names=['softmax_label'])
4.1 使用中層的介面來執行Module
要想訓練Module,通常需要完成以下的步驟: (1)bind——根據資料和標籤的形狀,系統為執行計算過程中所需要的環境分配記憶體; (2)init_params——引數初始化 (3)init_optimizer——設定優化器,預設是SGD (4)metric.create——設定用於評估的度量方式 (5)forward——前向傳播 (6)update_metric——評估並累積上一次前向計算的輸出的評估度量。 (7)backward——反向傳播 (8)update——更新網路引數
#根據輸入資料和標籤的形狀來分配記憶體
mod.bind(data_shapes=train_iter.provide_data,label_shapes=train_iter.provide_label)
#引數初始化
mod.init_params(initializer=mx.init.Uniform(scale=0.1))
#設定優化器
mod.init_optimizer(optimizer='sgd',optimizer_params=(('learning_rate',0.1),))
#使用準確率作為測量準則
metric=mx.metric.create('acc')
#開始訓練
for epoch in range(5):
train_iter.reset()
metric.reset()
for batch in train_iter:
mod.forward(batch,is_train=True)#前向傳播
mod.update_metric(metric,batch.label)#更新度量
mod.backward()#後向傳播
mod.update()#更新引數
print('Epoch %d ,training %s' %(epoch,metric.get()))
4.2 使用高層的介面執行Module
4.1中列出的全部操作都可以使用一次fit函式來完成,而不用自己寫多個繁瑣的步驟
train_iter.reset()#重置訓練資料迭代器 train_iter
mod.fit(train_iter,eval_data=val_iter,optimizer='sgd',optimizer_params={'learning_rate':0.1}, eval_metric='acc',num_epoch=8)
5 使用Module進行預測
y=mod.predict(val_iter)#生成預測的輸出值
assert y.shape==(4000,26)
score=mod.score(val_iter,['acc'])#直接獲得評估的準確率
print("Accuracy score is %f" %(score[0][1]))
assert score[0][1]>0.77, "Achieved accuracy (%f) is less than 0.77" %(score[0][1])
6儲存和載入Module的引數
使用checkpoint callback 我們可以讓Module在每一次訓練epoch之後都儲存一次引數
model_prefix='mx_mlp'
checkpoint=mx.callback.do_checkpoint(model_prefix)
mod.fit(train_iter,num_epoch=5,epoch_end_callback=checkpoint)
使用load_checkpoint函式來載入symbol和引數,然後把載入的引數灌入到Module中
#arg_param: 模型引數,以及網路權重字典。
#aux_params: 模型引數,以及一些附加狀態的字典
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
assert sym.tojson() == net.tojson()
mod.set_params(arg_params, aux_params)
如果想從恢復的斷點處繼續訓練模型,可以呼叫fit函式把加載出的引數灌入到Module中,並在此基礎上繼續訓練模型
mod.fit(train_iter,num_epoch=21,arg_params=arg_params,aux_params=aux_params,begin_epoch=3)