1. 程式人生 > >深度學習入門—BP演算法簡單實現

深度學習入門—BP演算法簡單實現

"""
BP演算法的簡單實現,這裡只有三層網路,目的在於說明其執行過程
除錯時可以控制輸入的迭代次數和學習率,這樣可以動態地看執行效果
當迭代次數過大時,會出現過擬合情況,親測
"""
import numpy as np
def sigmoid(x): #設定啟用函式
    return 1 / (1 + np.exp(-x))
def sigmoidDerivationx(y): #計算啟用函式的偏微分
    return y * (1 - y)
if __name__ == "__main__":
    alpha = 0.05 # 學習率,一般在(0,0.1)上取值
    numIter = 100000 # 迭代次數
    w1 = [[0.15, 0.20], [0.25, 0.30]]  # 輸入層的權重
    w2 = [[0.40, 0.45], [0.50, 0.55]]  # 權重矩陣的維數和輸入輸出維數有關係 要滿足矩陣相乘的條件
    b1 = 0.35 # 初始化偏置 可以已知,也可設定為未知
    b2 = 0.60
    # 你心裡應該清楚,實際生產中,輸入與輸出都是從對應的資料集中獲取得到的,這裡僅作為演示
    x = [0.05,0.10] # 初始化輸入
    y = [0.01,0.99] # 初始化對應的輸出label
    z1 = np.dot(w1,x) + b1
    a1 = sigmoid(z1) # 啟用函式,第一層激勵值
    z2 = np.dot(w2,x) + b2
    a2 = sigmoid(z2) # 同上,第二層激勵值
    for n in range(numIter): # 開始迭代
        # 反向傳播 使用代價函式為C=1 / (2n) * sum[(y-a2)^2]
        # 最後一層的梯度delta
        delta2 = np.multiply(-(y - a2),np.multiply(a2,1-a2))# 注意delta2這兩項乘積的由來,第一個引數就是代價函式對a2求的偏導(實際上得到的就是真實輸出值與預測輸出值之間的誤差而已),第二個引數是激勵值的導數,二者相乘最後得梯度的變化值
        # 若不明白這個過程,可以直接記住結論,感興趣的可以參考相關的詳細證明過程
        # 非最後一層的梯度delta(即隱含層),演算法與最後一層不一樣,因為其沒有輸出預測值,我們用下一層誤差的加權和來代替真實值與預測值的差
        delta1 = np.multiply(np.dot(np.array(w2).T, delta2), np.multiply(a1, 1 - a1))
        # 關鍵是要明白計算的過程與理論基礎
        # 計算完權重的變化後(即delta),更新權重,delta也可以稱為梯度的變化
        for i in range(len(w2)):
            w2[i] = w2[i] - alpha * delta2[i] * a1
        for i in range(len(w1)):
            w1[i] = w1[i] - alpha * delta1[i] * np.array(x)
        # 繼續前向傳播,算出誤差值
        z1 = np.dot(w1, x) + b1 # 用新的權重值再求一遍激勵
        a1 = sigmoid(z1)
        z2 = np.dot(w2, a1) + b2
        a2 = sigmoid(z2)
        print(str(n) + " result:" + str(a2[0]) + ", result:" + str(a2[1]))# 輸出迭代後的預測值
        print(str(n) + "  error1:" + str(y[0] - a2[0]) + ", error2:" +str(y[1] - a2[1])) # 輸出誤差,可以通過改變迭代次數來檢視效果




二、相關分析 麻雀雖小五臟俱全,可以通過除錯上面程式碼的引數,動態的觀看執行過程以及誤差的變化情況。 這個資料集是自己創造的,很小很小,沒有挑戰性。在看神經網路與深度學習這本書的時候,上面有一個識別手寫數字的demo,這個demo也有反向傳播實現的版本,是基於經典的mnist資料集來實現的。作者用的是python2.7的環境,而且對於初學者來說,徹底讀懂需要下點功夫,我學習的時候,用python3.6重新修改了一遍,而且在原始碼的基礎上增加了自己的理解註釋,已經上傳到了我的github,需要的可以參考:https://github.com/GritCoder/BP