1. 程式人生 > >透徹理解神經網路剪枝演算法

透徹理解神經網路剪枝演算法

1. 問題敘述
  心血來潮寫點最近做的成果,主要分成兩個部落格來進行闡述。研究生上了一年半看了不少關於剪枝神經網路方面的文章,但是有很少的文章能讓人感覺到耳目一新,打通了任督二脈的感覺。前段時間看到了一個剪枝演算法就有這種感覺。和大家分享下。
  全連線神經網路在很多方面都用的很多,這我就不贅述了,全連線有很強的逼近能力但是很容易導致過擬合。所以 機器學習與模式識別最核心的問題就是減小系統的複雜度(description -length【1】,VC-dimensions【2】),在神經網路中,這樣的核心問題就變成了減少連線權值的數量。
  減小模型複雜度方面,很常見的方法是在損失函式後面加上懲罰項

這裡寫圖片描述

  為了是模型複雜度減小,通常使用2範數【3】
這裡寫圖片描述

  但是2範數誤差使權值和閾值稀疏化,進而使用1範數對權值和閾值進行懲罰
這裡寫圖片描述

  進而陸續有很多關於懲罰項的改進【4】【5】【6】,接下來我要闡述的演算法是OBS演算法,很簡單,並且很容易實現,效果顯著。
2. OBS演算法詳述【7】
  OBS演算法是一種基於Hessian矩陣的網路修剪演算法,首先,構造誤差曲面的一個區域性模型,分析權值的擾動所造成的影響。
  通過對誤差函式進行Taylor展開
這裡寫圖片描述

H為Hessian矩陣,T表示矩陣的轉置,w為神經網路中的引數(包括權值和閾值), E為訓練集的訓練誤差,訓練神經網路用任意的優化演算法,該剪枝演算法都是適用的。
通過優化演算法(如L-M演算法)得到一個區域性最小點,則上式第一項為0,忽略第三項高階無窮小項。可以得到
這裡寫圖片描述

該方法通過將其中一個權值置為0,從而可以寫成

這裡寫圖片描述

這裡寫圖片描述為單位向量,只有在第q項為1其他的項為0。
當其中一個權值或者是閾值置為0時,使 這裡寫圖片描述最小,可以得到
這裡寫圖片描述

通過拉格朗日乘子法,可以將有約束優化問題轉化為無約束優化問題,
這裡寫圖片描述

這裡寫圖片描述為拉格朗日乘子,通過對函式 這裡寫圖片描述求偏導,可以得到 這裡寫圖片描述
這裡寫圖片描述導致誤差的變化為
這裡寫圖片描述

演算法流程圖如下
這裡寫圖片描述

3. 感想

  1.OBS演算法的全稱為optimal brain surgeon,翻譯成中文就是最優外科手術,表面的意思就是該方法是和神經網路過程是分開的。
  2.該方法是一種框架,只要是模型能求出引數的梯度,那麼都可用這個方法進行稀疏化。
4. 例子
   y=sin(x) 生成100個樣本,然後隨機生成(0,1)的噪聲加到乾淨的樣本上
A: 用全連線神經網路對y=sin(x)函式近似,如圖,隱層節點為17個
這裡寫圖片描述

B:通過剪枝演算法得到的網路為
這裡寫圖片描述

通過人為化簡為
這裡寫圖片描述

從上圖可以看出,全連線神經網路用來對函式y=sin(x)進行逼近只需要4個隱層節點,所以該演算法可以將多餘的隱層節點去掉,並且可以進行特徵選擇,將噪聲去掉。
5. 引用
[1] Barron, A., Rissanen, J., & Yu, B. (1998). The minimum description length principle in coding and modeling. IEEE Transactions on Information Theory, 44(6), 2743-2760.
[2] Vapnik, V. N., & Chervonenkis, A. Y. (2015). On the uniform convergence of relative frequencies of events to their probabilities Measures of complexity (pp. 11-30): Springer.
[3] Chow, M.-Y., & Teeter, J. (1994). An analysis of weight decay as a methodology of reducing three-layer feedforward artificial neural networks for classification problems. Paper presented at the Neural Networks, 1994. IEEE World Congress on Computational Intelligence., 1994 IEEE International Conference on.
[4] Weigend, A. S., Rumelhart, D. E., & Huberman, B. A. (1991). Generalization by weight-elimination with application to forecasting. Paper presented at the Advances in neural information processing systems.
[5] Hoyer, P. O. (2004). Non-negative matrix factorization with sparseness constraints. Journal of machine learning research, 5(Nov), 1457-1469.
[6] Zeng, H., & Trussell, H. J. (2010). Constrained dimensionality reduction using a Mixed-Norm penalty function with neural networks. IEEE Transactions on Knowledge and Data Engineering, 22(3), 365-380.
[7] Hassibi, B., & Stork, D. G. (1993). Second order derivatives for network pruning: Optimal brain surgeon. Paper presented at the Advances in neural information processing systems.