1. 程式人生 > >【深度學習】基於計算圖的反向傳播詳解

【深度學習】基於計算圖的反向傳播詳解

計算圖

計算圖就是將計算過程用圖形表示出來,這裡所說的圖形是資料結構圖,通過多個節點和邊表示(邊是用來連線節點的)。

下面我們先來通過一個簡單的例子瞭解計算圖的計算過程

假設我們有如下需求:

  • 一個蘋果100塊錢,一個橘子150塊錢
  • 消費稅為10%
  • 買了2個蘋果,3個橘子,一共需要支付多少錢?

1、根據需要構建計算圖

在這裡插入圖片描述

2、在計算圖上從左向右進行計算

  • 按著圖中箭頭方向“從左向右進行計算”稱為正向傳播,即從計算圖的出發點到結束點的傳播
  • 自然,“從右往左計算”稱為反向傳播

3、區域性計算

  • 上圖中,我們對於蘋果和橘子的計算是分開的,然後合併。這裡的區域性也就是在計算過程中只將與自己相關的資訊進行計算輸出結果;
  • 計算圖可以集中精力於區域性計算,無論全域性多麼複雜,各個步驟說要做的就是物件節點的區域性計算,這樣就可以通過區域性計算,將結果傳遞下去,就可以獲得全域性的複雜計算結果

4、計算圖求解的好處

  • 區域性計算,可以是各個節點只致力於簡單的計算,從而簡化問題
  • 利用計算圖可以將中間的計算結果儲存起來,以免重複計算
  • 可以通過反向傳播高效計算導數這一點是最重要的

我們來考慮一個問題:

  • 假設我們想知道蘋果價格的上漲會在多大程度上影響最終的支付金額,即求“支付金額關於蘋果的價格的導數”
  • 設蘋果價格為 x
    x
    ,支付金額為 L L ,則相當於求 L
    x \frac{\partial L}{\partial x}

在這裡插入圖片描述

圖中,反向傳播“區域性導數”,將導數的值解除安裝箭頭下方,圖中紅色表示
從右向左(1 -> 1.1 -> 2.2 )
這意味著,如果蘋果的價格上漲1塊錢,最終的支付金額會增加2.2塊錢
關於如何計算的,後面會介紹
同樣的我們也可以計算出“支付金額關於蘋果個數的導數”、“支付金額關於消費稅的導數”

鏈式法則

  • 關於偏導數的鏈式法則,在高等數學中有具體內容,如果不知道可以參考下
  • 從上面的計算我們知道正向傳播計算過程就是我們日常的計算過程,所以很容易理解
  • 反向傳播區域性導數的原理,就是基於鏈式法則的

1、計算圖的方向傳播

在這裡插入圖片描述

反向傳播的計算順序:

  • 將訊號 E E 乘以節點的區域性導數 y x \frac{\partial y}{\partial x}
  • 傳遞給下一個節點

通過這樣的計算,可以高效地求出導數的值

2、鏈式法則和計算圖

z = t 2 t = x + y z = t^2 \\ t = x+y

在這裡插入圖片描述

反向傳播的計算順序:

  • 將節點的輸入訊號乘以節點的區域性導數(偏導數),然後傳遞給下一個節點
    比如,反向傳播時,“**2”節點的輸入時 z z \frac{\partial z}{\partial z} ,將其乘以區域性導數 z t \frac{\partial z}{\partial t}
  • 然後再將上一步的輸出 z z z t \frac{\partial z}{\partial z} \frac{\partial z}{\partial t} 作為下一節點的輸入,同樣乘以區域性導數 t x \frac{\partial t}{\partial x}

根據鏈式法則:

  • z z z t t x = z t t x = z x \frac{\partial z}{\partial z} \frac{\partial z}{\partial t} \frac{\partial t}{\partial x} = \frac{\partial z}{\partial t} \frac{\partial t}{\partial x} = \frac{\partial z}{\partial x}
    對應於“ z z 關於 x x 的導數”

反向傳播

1、加法節點的反向傳播實現

這裡以 z = x + y z = x+y 為物件來說明

z x = 1 \frac{\partial z}{\partial x} = 1
z y = 1 \frac{\partial z}{\partial y} = 1

在這裡插入圖片描述

如圖, z x = 1 z y = 1 \frac{\partial z}{\partial x} = 1 ,\frac{\partial z}{\partial y} = 1
反向傳播將上游傳過來的導數乘以1,然後傳向下遊

也就是說,因為加法節點的反向傳播只乘以1,所以輸入的值會原封不動地流向下一個節點

  • python實現加法層
class AddLayer:
    def __init__(self):
        pass
	# 正向傳播
    def forward(self, x, y):
        out = x + y

        return out
	# 反向傳播
    def backward(self, dout):
        dx = dout * 1
        dy = dout * 1

        return dx, dy

2、乘法節點的反向傳播實現

這裡以 z = x y z = xy 為物件來說明

L z = y \frac{\partial L}{\partial z} = y
z y = x \frac{\partial z}{\partial y} = x

在這裡插入圖片描述

乘法節點的反向傳播需要正向傳播時的輸入訊號值,因此,實現乘法節點的反向傳播時,需要儲存正向傳播的輸入訊號

乘法節點的反向傳播會乘以輸入訊號的翻轉值

  • python實現乘法層
class MulLayer:
    def __init__(self):
        self.x = None
        self.y = None
	# 正向傳播
    def forward(self, x, y):
        self.x = x
        self.y = y                
        out = x * y

        return out
	# 反向傳播
    def backward(self, dout):
        dx = dout * self.y
        dy = dout * self.x

        return dx, dy

# 測試
apple = 100 # 蘋果價格
apple_num = 2 # 蘋果個數
tax = 1.1 # 消費稅

mul_apple_layer = MulLayer() # 建立乘法器物件
mul_tax_layer = MulLayer() # 建立乘法器物件

# forward
apple_price = mul_apple_layer.forward(apple, apple_num) # 2個蘋果的價格
price = mul_tax_layer.forward(apple_price, tax) # 支付金額

# backward
dprice = 1
dapple_price, dtax = mul_tax_layer.backward(dprice) 
dapple, dapple_num = mul_apple_layer.backward(dapple_price)

print("price:", int(price))
print("dApple:", dapple)
print("dApple_num:", int(dapple_num))
print("dTax:", dtax)
輸出為:
price: 220
dApple: 2.2
dApple_num: 110
dTax: 200

在這裡插入圖片描述

結果與上圖中的反向傳播的結果一樣

  • 蘋果和橘子的例子的實現
apple = 100
apple_num = 2
orange = 150
orange_num = 3
tax = 1.1

# layer
mul_apple_layer = MulLayer()
mul_orange_layer = MulLayer()
add_apple_orange_layer = AddLayer()
mul_tax_layer = MulLayer()

# forward
apple_price = mul_apple_layer.forward(apple, apple_num)  # (1)
orange_price = mul_orange_layer.forward(orange, orange_num)  # (2)
all_price = add_apple_orange_layer.forward(apple_price, orange_price)  # (3)
price = mul_tax_layer.forward(all_price, tax)  # (4)

# backward
dprice = 1
dall_price, dtax = mul_tax_layer.backward(dprice)  # (4)
dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price)  # (3)
dorange, dorange_num = mul_orange_layer.backward(dorange_price)  # (2)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)  # (1)

print("price:", int(price))
print("dApple:", dapple)
print("dApple_num:", int(dapple_num))
print("dOrange:", dorange)
print("dOrange_num:", int(dorange_num))
print("dTax:", dtax)
輸出為:
price: 715
dApple: 2.2
dApple_num: 110
dOrange: 3.3000000000000003
dOrange_num: 165
dTax: 650

在這裡插入圖片描述

啟用函式層的反向傳播實現

1、ReLU層

  • 數學表示式:
    y = { x ( x > 0 ) 0 ( x 0 ) y = \begin{cases} x \quad (x > 0) \\ 0 \quad (x \leqslant0) \end{cases}