1. 程式人生 > >【ML模型詳細推導3】- 感知機

【ML模型詳細推導3】- 感知機

感知機

1. 一句話介紹感知機

· 一個二分類的線性分類模型。
· 感知機學習旨在求出將訓練集資料進行線性劃分的分類超平面,匯入基於誤分類的損失函式,利用梯度下降法對損失函式進行極小化,求得模型。
· 分為原始形式和對偶形式。

2. 模型介紹(原始形式)

前提:資料集要求線性可分性
- 模型:
f (

x ) = s i g n ( w x
+ b ) f(x)=sign(wx+b)
s i g n ( x ) = { + 1 , x 0 1 x < 0 其中,sign(x)=\left\{\begin{matrix} +1,& x\geqslant 0 & \\ -1 & x<0& \end{matrix}\right.

- 損失函式(學習策略)
損失函式,誤分類點到超平面的“函式間隔”和:
L ( w , b ) = x i M y i ( w x i + b ) L(w,b) = -\sum_{x_i\in M}y_i (wx_i+b) 其中, M M 為誤分類的集合。

函式間隔待了解
(Q1:什麼是函式間隔,和幾何間隔的區別)

- 隨機梯度下降(學習演算法)

隨機梯度下降法,最小化損失函式 L ( w , b ) L(w,b)
梯度:
w L ( w , b ) = x i M y i x i \triangledown _wL(w,b)=-\sum _{x_i\in M}y_i x_i \\ b L ( w , b ) = x i M y i \triangledown _bL(w,b)=-\sum _{x_i\in M}y_i
隨機選取一個數據點 ( x i , y i ) (x_i,y_i) ,判斷其是否誤分類,也就是判斷 y i ( w x + b ) < 0 y_i\cdot(wx+b)<0 與否
y i f ( x ) < 0 y_i f(x)<0 ,對 w , b w,b 進行更新:
w w + η y i x i w\leftarrow w+\eta y_ix_i b b + η y i b\leftarrow b+\eta y_i 其中, η \eta 為學習率。
感知機學習演算法是誤分類驅動的,這樣,因為資料線性可分,通過迭代可以直至訓練集中沒有誤分類點。

3. 感知機的對偶形式

- 模型:
將訓練集 { ( x 1 , y 1 ) ( x 2 , y 2 ) ( x i , y i ) } \{(x_1,y_1)(x_2,y_2)…(x_i,y_i)\} 全部代入進行隨機下降,若第 i 個樣本用的次數為 n i n_i ,則最後訓練出的 w w b b
w = i = 1 N n i η y i x i w = \sum _{i=1}^{N}n_i\eta y_ix_i b = i = 1 N n i η y i b = \sum _{i=1}^{N}n_i\eta y_i

w w b b 帶入感知機原始形式得對偶形式:
f ( x ) = s i g n ( j = 1 N n j η y j x j x + j = 1 N n j η y j ) f(x) = sign(\sum _{j=1}^{N}n_j\eta y_jx_j \cdot x+ \sum _{j=1}^{N}n_j\eta y_j)