1. 程式人生 > >淺談神經網路中的啟用函式

淺談神經網路中的啟用函式

       啟用函式是神經網路中一個重要的環節,本文將介紹為什麼神經網路網路要利用啟用函式,幾種常用的啟用函式(邏輯函式Sigmoid、雙曲正切函式tanh、線性整流函式(ReLU),神經網路中的梯度消失問題和ReLU如何避免梯度消失。

1 用啟用函式的原因

       如果神經網路沒有進行可以提取非線性特徵的卷積操作,而且該神經網路也不用啟用函式,那麼這個神經網路第i層輸出只有Wxi+b。這樣此神經網路不論有多少層,第i層的輸出都是一個關於第i層輸入xi的線性組合,相當於此時多層神經網路退化為一個多層的線性迴歸模型,難以學習如影象、音訊、文字等複雜資料的特徵。

       正因為這個原因,神經網路要引入啟用函式來給神經網路增加一些非線性的特性,所以目前常見的啟用函式大多是非線性函式。這樣神經網路中下一層得到的輸入不再是線性組合了。

2 常見的啟用函式

2.1 邏輯函式Sigmoid [1]

       邏輯函式(logistic function)或邏輯曲線(logistic curve)是一種常見的S函式,它是皮埃爾·弗朗索瓦·韋呂勒在1844或1845年在研究它與人口增長的關係時命名的。

       一個簡單的Logistic函式表示式為:

\[ f\left( x \right) = \frac{1}{{1 + {e^{ - x}}}} \]


圖1 標準邏輯函式的影象

       邏輯函式形如S,所以通常也叫做S形函式。

       從函式影象易知f(x)的定義域為[-∞, +∞], 值域是(0,1)

       對f(x)求導數,易得

\[f'\left( x \right) = {\left( {\frac{1}{{1 + {e^{ - x}}}}} \right)^\prime } = \frac{{{e^{ - x}}}}{{{{\left( {1 + {e^{ - x}}} \right)}^2}}}\;\; = f\left( x \right)\left( {1 - f\left( x \right)} \right)\]

2.2 雙曲正切函式tanh [2]

       雙曲正切函式是雙曲函式的一種。在數學中,雙曲函式是一類與常見的三角函式類似的函式。雙曲正切函式的定義為

\[f\left( x \right) = \tanh \left( x \right) = \frac{{{e^x} - {e^{ - x}}}}{{{e^x} + {e^{ - x}}}}\]


圖2 雙曲正切函式的影象(同邏輯函式類似)

       從函式影象易知f(x)的定義域為[-∞, +∞], 值域是(-1,1)

       對f(x)求導數,易得

\[f'\left( x \right) = {\left( {\frac{{{e^x} - {e^{ - x}}}}{{{e^x} + {e^{ - x}}}}} \right)^\prime } = \frac{4}{{{{\left( {{e^x} + {e^{ - x}}} \right)}^2}}}\;\; = 1 - f{\left( x \right)^2}\]

2.3 線性整流函式ReLU [3]

       線性整流函式(Rectified Linear Unit, ReLU),又稱修正線性單元, 是一種人工神經網路中常用的啟用函式,通常指代以斜坡函式及其變種為代表的非線性函式。

       通常意義下,線性整流函式指代數學中的斜坡函式,即

\[f\left( x \right) = \left\{ \begin{array}{l} x\quad \quad x \ge 0 \\ 0\quad \quad x < 0 \\ \end{array} \right.\]


圖3 ReLU函式影象

       從函式影象易知f(x)的定義域為[-∞, +∞], 值域是[0, +∞)

       對f(x)求導數,易得

\[f'\left( x \right) = \left\{ \begin{array}{l} 1\quad \quad x \ge 0 \\ 0\quad \quad x < 0 \\ \end{array} \right.\]

3 梯度消失問題和ReLU如何處理此問題

       使用S形函式作為啟用的神經網路中,隨著神經網路的層數增加,神經網路後面層在梯度下降中求導的梯度幾乎為0,從而導致神經網路網路後面層的權值矩陣幾乎無法更新。表現為隨著隱藏層數目的增加,分類準確率反而下降了。這種現象叫做消失的梯度問題。

       假設神經網路只有三層,用S型函式作為啟用函式

       第一層輸入為x, 輸出為S(W1x+b1)

       第二層輸入為S(W1x+b1),輸出為S(W2S(W1x+b1)+b2)

       第三層輸入為S(W2S(W1x+b1)+b2),輸出為S(W3S(W2S(W1x+b1)+b2)+b3)

       同時簡記住每層在啟用函式處理前的值為ai, 輸出為fi

       假設最後損失函式為L,L是一個關於f3的函式,那麼求導易得

\[\begin{array}{l} \frac{{\partial L}}{{\partial {W_1}}} = \frac{{\partial L}}{{\partial {f_3}}} \cdot \frac{{\partial S\left( {{W_3}S\left( {{W_2}S\left( {{W_1}x + {b_1}} \right) + {b_2}} \right) + {b_3}} \right)}}{{\partial {W_1}}} \\ \quad \quad = \frac{{\partial L}}{{\partial {f_3}}} \cdot \frac{{\partial S}}{{\partial {a_3}}} \cdot \frac{{\partial {W_3}S\left( {{W_2}S\left( {{W_1}x + {b_1}} \right) + {b_2}} \right) + {b_3}}}{{\partial {W_1}}} \\ \quad \quad = \frac{{\partial L}}{{\partial {f_3}}} \cdot \frac{{\partial S}}{{\partial {a_3}}} \cdot {W_3} \cdot \frac{{\partial S\left( {{W_2}S\left( {{W_1}x + {b_1}} \right) + {b_2}} \right)}}{{\partial {W_1}}} \\ \quad \quad = \cdots \\ \quad \quad = \frac{{\partial L}}{{\partial {f_3}}} \cdot \frac{{\partial S}}{{\partial {a_3}}} \cdot {W_3} \cdot \frac{{\partial S}}{{\partial {a_2}}} \cdot {W_2} \cdot \frac{{\partial S}}{{\partial {a_1}}} \cdot \frac{{\partial {a_1}}}{{\partial {W_1}}} \\ \end{array}\]

       其中偏導數∂S/ ∂ai是造成梯度消失的原因,因為S函式的導數閾值為

\[f'\left( x \right) = \frac{{{e^{ - x}}}}{{{{\left( {1 + {e^{ - x}}} \right)}^2}}}\;\; \in \left( {0,\left. {\frac{1}{4}} \right]} \right.\]

       即有0<∂S/ ∂a1≤0.25, 0<∂S/ ∂a2≤0.25, 0<∂S/ ∂3≤0.25, 在損失函式偏導表示式中三個偏導數相乘有:

\[0 < \frac{{\partial S}}{{\partial {a_3}}}\frac{{\partial S}}{{\partial {a_2}}}\frac{{\partial S}}{{\partial {a_1}}} \le 0.015625\]

       這樣會減小損失函式的數值,如果神經網路是20層,則有

\[0 < \frac{{\partial S}}{{\partial {a_{20}}}}\frac{{\partial S}}{{\partial {a_{19}}}} \cdots \frac{{\partial S}}{{\partial {a_1}}} \le {0.25^{20}} = {\rm{9}}.0{\rm{94}} \times {10^{ - 13}}\]

       這是一個更小的數,所以神經網路後幾層求第一層引數W1的梯度就非常小。而ReLU函式就是為了避免梯度消失問題,因為ReLU求導只有兩個值1或0,這樣的話只要神經網路梯度中一條路徑上的導數都是1,那麼無論網路有多少層,網路後幾層的梯度都可以傳播到網路前幾層。

參考資料

  1. https://en.wikipedia.org/wiki/Logistic_function
  2. https://en.wikipedia.org/wiki/Hyperbolic_function
  3. https://en.wikipedia.org/wiki/Rectifier_(neural_networks)