1. 程式人生 > >前向傳播與反向傳播

前向傳播與反向傳播

本篇主要介紹神經網路的引數更新方法

 

在介紹引數更新方法之前,需要知道損失函式(loss function)。

損失函式的作用是衡量模型預測值與實際值之間的差異。

一般神經網路用的損失函式是:交叉熵損失(cross entropy)。

當pi=1時,函式影象如下:

x軸為qi,y軸為loss值(x軸的定義域一般在[0,1]之間)

上面公式pi指模型的實際值,qi值的是通過神經網路得到的預測值。n為模型預測值qi擁有的類別數量。上篇講過一般輸出層為一維向量,對應的loss就是pi和qi的點乘再求和,因為一般預測的pi是one-hot型別,只有一個值為1,其餘都為0。所以loss是pi為1的對應的log(qi)。可以看出如果qi接近於1,那麼loss的值就為0(如上面函式影象)。所以我們的目的就是使所有訓練樣本loss的值接近0。我們通過更新權重w來使模型的loss接近0,那麼這些w到底如何更新才算最好?從上面影象看(定義域[0,1]之間),如果qi=1的值是函式的最小值,求w使得函式loss最小,有木有很熟悉。沒錯,這就需要用到偏導數了!!即梯度下降演算法

 

前向傳播與反向傳播

前向傳播:沿著神經網路輸入到輸出,更新權重w

反向傳播:沿著神經網路輸出到輸入,更新權重w

其實這兩種引數更新方法的本質都是基於梯度下降演算法。只是由於神經網路的層數過多,引數量過大,再使用前向傳播時,計算速度和資源消耗過大。這個時候才出現了反向傳播演算法。

在瞭解前向傳播和反向傳播之前,需要了解複合函式求導法。因為神經網路就可以寫成一個複合函式,我們要求使得loss最小的每個w,就需要計算每個w的偏導數。

一般網路都是需要對loss求導,但下面為了方便我只用優化S(即目標y)

整個網路的複合函式為:

對權重w求導可以抽象為:

其中δi(3)為上一步求導結果,aj2為前向傳播計算的結果。

以上述網路結構示意,帶入w求導公式為:

可以看到δ(2)是δ(s)/δ(w2)的結果,之前已經計算過了,而a(1)也是在計算y的預測值時計算過了。

所以反向傳播演算法是,先進行前向傳播計算a(1),a(2)的值,並存儲在記憶體中。然後從更新輸出層到倒數第一個隱藏層之間的權重,將計算的梯度儲存到記憶體中,再更新倒數第一個隱藏層到倒數第二個隱藏層之間的梯度,這個時候需要用到之前儲存的梯度,計算方式為之前儲存的梯度*倒數第二隱藏層計算的a(n-1)的值。

使用反向傳播比直接從輸入層到輸出層更新,計算量要小很多。直接前向更新引數,會做許多重複的運算,白白浪費計算資源。

 

本文主要參考了stanford的CS224N關於反向傳播的例子。

CS224N網址:https://ai.yanxishe.com/page/VideoTranslation/157