1. 程式人生 > >[ch04-01] 用最小二乘法解決線性迴歸問題

[ch04-01] 用最小二乘法解決線性迴歸問題

系列部落格,原文在筆者所維護的github上:https://aka.ms/beginnerAI,
點選star加星不要吝嗇,星越多筆者越努力。

4.1 最小二乘法

4.1.1 歷史

最小二乘法,也叫做最小平方法(Least Square),它通過最小化誤差的平方和尋找資料的最佳函式匹配。利用最小二乘法可以簡便地求得未知的資料,並使得這些求得的資料與實際資料之間誤差的平方和為最小。最小二乘法還可用於曲線擬合。其他一些優化問題也可通過最小化能量或最小二乘法來表達。

1801年,義大利天文學家朱賽普·皮亞齊發現了第一顆小行星穀神星。經過40天的跟蹤觀測後,由於穀神星執行至太陽背後,使得皮亞齊失去了穀神星的位置。隨後全世界的科學家利用皮亞齊的觀測資料開始尋找穀神星,但是根據大多數人計算的結果來尋找穀神星都沒有結果。時年24歲的高斯也計算了穀神星的軌道。奧地利天文學家海因裡希·奧爾伯斯根據高斯計算出來的軌道重新發現了穀神星。

高斯使用的最小二乘法的方法發表於1809年他的著作《天體運動論》中。法國科學家勒讓德於1806年獨立發明“最小二乘法”,但因不為世人所知而默默無聞。勒讓德曾與高斯為誰最早創立最小二乘法原理髮生爭執。

1829年,高斯提供了最小二乘法的優化效果強於其他方法的證明,因此被稱為高斯-馬爾可夫定理。

4.1.2 數學原理

線性迴歸試圖學得:

\[z(x_i)=w \cdot x_i+b \tag{1}\]

使得:

\[z(x_i) \simeq y_i \tag{2}\]

其中,\(x_i\)是樣本特徵值,\(y_i\)是樣本標籤值,\(z_i\)是模型預測值。

如何學得w和b呢?均方差(MSE - mean squared error)是迴歸任務中常用的手段:
\[ J = \sum_{i=1}^m(z(x_i)-y_i)^2 = \sum_{i=1}^m(y_i-wx_i-b)^2 \tag{3} \]

\(J\)稱為損失函式。實際上就是試圖找到一條直線,使所有樣本到直線上的殘差的平方和最小。

圖4-3 均方差函式的評估原理

圖4-3中,圓形點是樣本點,直線是當前的擬合結果。如左圖所示,我們是要計算樣本點到直線的垂直距離,需要再根據直線的斜率來求垂足然後再計算距離,這樣計算起來很慢;但實際上,在工程上我們通常使用的是右圖的方式,即樣本點到直線的豎直距離,因為這樣計算很方便,用一個減法就可以了。

假設我們計算出初步的結果是虛線所示,這條直線是否合適呢?我們來計算一下圖中每個點到這條直線的距離,把這些距離的值都加起來(都是正數,不存在互相抵消的問題)成為誤差。

因為上圖中的幾個點不在一條直線上,所以不能有一條直線能同時穿過它們。所以,我們只能想辦法不斷改變紅色直線的角度和位置,讓總體誤差最小(用於不可能是0),就意味著整體偏差最小,那麼最終的那條直線就是我們要的結果。

如果想讓誤差的值最小,通過對w和b求導,再令導數為0(到達最小極值),就是w和b的最優解。

推導過程如下:

\[ \begin{aligned} {\partial{J} \over \partial{w}} &={\partial{(\sum_{i=1}^m(y_i-wx_i-b)^2)} \over \partial{w}} \\ &= 2\sum_{i=1}^m(y_i-wx_i-b)(-x_i) \end{aligned} \tag{4} \]

令公式4為0:

\[ \sum_{i=1}^m(y_i-wx_i-b)x_i=0 \tag{5} \]

\[ \begin{aligned} {\partial{J} \over \partial{b}} &={\partial{(\sum_{i=1}^m(y_i-wx_i-b)^2)} \over \partial{b}} \\ &=2\sum_{i=1}^m(y_i-wx_i-b)(-1) \end{aligned} \tag{6} \]

令公式6為0:

\[ \sum_{i=1}^m(y_i-wx_i-b)=0 \tag{7} \]

由式7得到(假設有m個樣本):

\[ \sum_{i=1}^m b = m \cdot b = \sum_{i=1}^m{y_i} - w\sum_{i=1}^m{x_i} \tag{8} \]

兩邊除以m:

\[ b = {1 \over m}(\sum_{i=1}^m{y_i} - w\sum_{i=1}^m{x_i})=\bar y-w \bar x \tag{9} \]

其中:

\[ \bar y = {1 \over m}\sum_{i=1}^m y_i, \bar x={1 \over m}\sum_{i=1}^m x_i \tag{10} \]

將公式10代入公式5:

\[ \sum_{i=1}^m(y_i-wx_i-\bar y + w \bar x)x_i=0 \]

\[ \sum_{i=1}^m(x_i y_i-wx^2_i-x_i \bar y + w \bar x x_i)=0 \]

\[ \sum_{i=1}^m(x_iy_i-x_i \bar y)-w\sum_{i=1}^m(x^2_i - \bar x x_i) = 0 \]

\[ w = {\sum_{i=1}^m(x_iy_i-x_i \bar y) \over \sum_{i=1}^m(x^2_i - \bar x x_i)} \tag{11} \]

將公式10代入公式11:

\[ w={\sum_{i=1}^m (x_i \cdot y_i) - \sum_{i=1}^m x_i \cdot {1 \over m} \sum_{i=1}^m y_i \over \sum_{i=1}^m x^2_i - \sum_{i=1}^m x_i \cdot {1 \over m}\sum_{i=1}^m x_i} \tag{12} \]

分子分母都乘以m:

\[ w={m\sum_{i=1}^m x_i y_i - \sum_{i=1}^m x_i \sum_{i=1}^m y_i \over m\sum_{i=1}^m x^2_i - (\sum_{i=1}^m x_i)^2} \tag{13} \]

\[ b=\frac{1}{m}\sum_{i=1}^m(y_i-wx_i) \tag{14} \]

而事實上,式13有很多個變種,大家會在不同的文章裡看到不同版本,往往感到困惑,比如下面兩個公式也是正確的解:

\[ w = {\sum_{i=1}^m y_i(x_i-\bar x) \over \sum_{i=1}^m x^2_i - (\sum_{i=1}^m x_i)^2/m} \tag{15} \]

\[ w = {\sum_{i=1}^m x_i(y_i-\bar y) \over \sum_{i=1}^m x^2_i - \bar x \sum_{i=1}^m x_i} \tag{16} \]

以上兩個公式,如果把公式10代入,也應該可以得到和式13相同的答案,只不過需要一些運算技巧。比如,很多人不知道這個神奇的公式:

\[ \begin{aligned} \sum_{i=1}^m (x_i \bar y) &= \bar y \sum_{i=1}^m x_i =\frac{1}{m}(\sum_{i=1}^m y_i) (\sum_{i=1}^m x_i) \\ &=\frac{1}{m}(\sum_{i=1}^m x_i) (\sum_{i=1}^m y_i)= \bar x \sum_{i=1}^m y_i \\ &=\sum_{i=1}^m (y_i \bar x) \end{aligned} \tag{17} \]

4.1.3 程式碼實現

我們下面用Python程式碼來實現一下以上的計算過程:

計算w值

# 根據公式15
def method1(X,Y,m):
    x_mean = X.mean()
    p = sum(Y*(X-x_mean))
    q = sum(X*X) - sum(X)*sum(X)/m
    w = p/q
    return w

# 根據公式16
def method2(X,Y,m):
    x_mean = X.mean()
    y_mean = Y.mean()
    p = sum(X*(Y-y_mean))
    q = sum(X*X) - x_mean*sum(X)
    w = p/q
    return w

# 根據公式13
def method3(X,Y,m):
    p = m*sum(X*Y) - sum(X)*sum(Y)
    q = m*sum(X*X) - sum(X)*sum(X)
    w = p/q
    return w

由於有函式庫的幫助,我們不需要手動計算sum(), mean()這樣的基本函式。

計算b值

# 根據公式14
def calculate_b_1(X,Y,w,m):
    b = sum(Y-w*X)/m
    return b

# 根據公式9
def calculate_b_2(X,Y,w):
    b = Y.mean() - w * X.mean()
    return b

4.1.4 運算結果

用以上幾種方法,最後得出的結果都是一致的,可以起到交叉驗證的作用:

w1=2.056827, b1=2.965434
w2=2.056827, b2=2.965434
w3=2.056827, b3=2.965434

程式碼位置

ch04, Lev