1. 程式人生 > >mxnet下如何檢視中間結果

mxnet下如何檢視中間結果

https://blog.csdn.net/disen10/article/details/79376631

檢視權重

在訓練過程中,有時候我們為了debug而需要檢視中間某一步的權重資訊,在mxnet中,我們可以很方便的呼叫get_params()方法來得到權重資訊。

  1.   '''
  2.   檢視權重示例程式碼
  3.   轉載時註明地址:http://blog.csdn.net/u010414386?viewmode=contents
  4.   '''
  5.   import mxnet as mx
  6.   sym, arg_params, aux_params = mx.model.load_checkpoint( 'resnet-50',0)#載入模型
  7.   mod = mx.mod.Module(symbol=sym,context=mx.gpu()) #建立Module
  8.   mod.bind(for_training= False,data_shapes=[('data',(1,3,224,224))]) #繫結,此程式碼為預測程式碼,所以training引數設為False
  9.   mod.set_params(arg_params,aux_params)
  10.   import numpy as np
  11.   import cv2
  12.   def get_image(filename):
  13.   img = cv2.imread(filename)
  14.   img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
  15.   img = cv2.resize(img,( 224,224))
  16.   img = np.swapaxes(img, 0,2)
  17.   img = np.swapaxes(img, 1,2)
  18.   img = img[np.newaxis,:]
  19.   return img
  20.   from collections import namedtuple
  21.   Batch = namedtuple( 'Batch',['data'])
  22.   img = get_image( 'val_1000/0.jpg') #獲取圖片
  23.   mod.forward(Batch([mx.nd.array(img)])) #預測結果
  24.   ################################################
  25.   #debug模式下,獲取權重資訊
  26.   keys = mod.get_params()[ 0].keys() # 列出所有權重名稱
  27.   conv_w = mod.get_params()[ 0]['conv0_weight'] #獲取想要檢視的權重資訊,如conv_weight
  28.   print conv_w.asnumpy() #檢視具體數值
  29.   ################################################
  30.   prob = mod.get_outputs()[ 0].asnumpy()
  31.   y = np.argsort(np.squeeze(prob))[:: -1]
  32.   print( 'truth label %d; top-1 predict label %d' % (val_label[0], y[0]))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

檢視中間輸出結果

由於mxnet的網路由symbol組成,而symbol又屬於符號式程式設計,所以我們不能像上面檢視權重一樣直接檢視,我們需要把我們想看的輸出結果儲存下來。

  1.   '''
  2.   方法一
  3.   檢視中間結果程式碼
  4.   轉載時註明地址:http://blog.csdn.net/u010414386?viewmode=contents
  5.   '''
  6.   import mxnet as mx
  7.   net = mx.symbol.Variable( 'data')
  8.   fc1 = mx.symbol.FullyConnected(data=net, name= 'fc1', num_hidden=128)
  9.   net = mx.symbol.Activation(data=fc1, name= 'relu1', act_type="relu")
  10.   net = mx.symbol.FullyConnected(data=net, name= 'fc2', num_hidden=64)
  11.   out = mx.symbol.SoftmaxOutput(data=net, name= 'softmax')
  12.   # 通過把兩個輸出組成一個group來得到自己需要檢視的中間層輸出結果
  13.   group = mx.symbol.Group([fc1, out])
  14.   print group.list_outputs()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  1.   '''
  2.   方法二
  3.   有時候我們使用別人的模型,所以無法像方法一一樣在定義模型的時候就確定需要檢視的中間層輸出結果,
  4.   這時候我們使用get_internals()方法來查詢自己需要檢視的中間層
  5.   轉載時註明地址:http://blog.csdn.net/u010414386?viewmode=contents
  6.   '''
  7.   import mxnet as mx
  8.   sym, arg_params, aux_params = mx.model.load_checkpoint( 'resnet-50',0)#載入模型
  9.   ########################################################################
  10.   args = sym.get_internals().list_outputs() #獲得所有中間輸出
  11.   internals = model.symbol.get_internals()
  12.   fc1 = internals[ 'fc1_output']
  13.   conv = internals[ 'stage4_unit3_conv1_output']
  14.   group = mx.symbol.Group([fc1, sym, conv]) #把需要輸出的結果按group方式組合起來,這樣就可以得到中間層的輸出
  15.   #########################################################################
  16.   mod = mx.mod.Module(symbol=group,context=mx.gpu()) #建立Module
  17.   mod.bind(for_training= False,data_shapes=[('data',(1,3,224,224))]) #繫結,此程式碼為預測程式碼,所以training引數設為False
  18.   mod.set_params(arg_params,aux_params)
  19.   import numpy as np
  20.   import cv2
  21.   def get_image(filename):
  22.   img = cv2.imread(filename)
  23.   img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
  24.   img = cv2.resize(img,( 224,224))
  25.   img = np.swapaxes(img, 0,2)
  26.   img = np.swapaxes(img, 1,2)
  27.   img = img[np.newaxis,:]
  28.   return img
  29.   from collections import namedtuple
  30.   Batch = namedtuple( 'Batch',['data'])
  31.   img = get_image( 'val_1000/0.jpg') #獲取圖片
  32.   mod.forward(Batch([mx.nd.array(img)])) #預測結果
  33.   prob = mod.get_outputs()[ 0].asnumpy()
  34.   y = np.argsort(np.squeeze(prob))[:: -1]
  35.   print( 'truth label %d; top-1 predict label %d' % (val_label[0], y[0]))