1. 程式人生 > >「05」迴歸的誘惑:一文讀懂線性迴歸

「05」迴歸的誘惑:一文讀懂線性迴歸

前言

從這一篇文章開始,就正式進入「美團」演算法工程師帶你入門機器學習系列的正文了,之前的幾篇算是導讀和預熱,想必大家看的並不過癮。從這裡開始,我們將會以線性迴歸為起點,貫通迴歸方法在機器學習演算法中所扮演的角色、具有的功能和使用的方法。

說起迴歸,它是我們在高中時就接觸過的內容。具體的,迴歸(Regression)是指研究一組隨機變數(Y1 ,Y2 ,…,Yi)和另一組隨機變數(X1,X2,…,Xk)之間關係的統計分析方法,又稱多重回歸分析。通常Y1,Y2,…,Yi是因變數,X1、X2,…,Xk是自變數。因變數,就是指被影響、決定的變數,本身不參與運算,而自變數則是指自身發生變化、改變並參與運算,最終影響因變數的變數。這些內容都是高中學習過的基礎,這裡僅僅做個回顧,不深入複習。

現在,讓我們先拋開機器學習、演算法、模型這類名詞,從最簡單的線性迴歸來看看,到底什麼是迴歸(的誘惑)

 

 

線性迴歸是什麼?

我們前面提到過,迴歸是計算因變數和自變數之間統計關係的一種方法。而線性迴歸可以理解為學習變數之間線性關係的方法。作為一切迴歸的基礎,它已經存在了時間長了,是無數教科書的主題。

雖然看起來,線性迴歸與一些更現代的統計學習方法,比如支援向量機相比,有些過於簡單。但在我們後續章節介紹的方法中,線性迴歸仍是一種非常有用的統計學習方法,它可以被用來簡單的測試資料和分析資料。 基於這個出發點,許多有趣的機器學習方法可以被看作是概線性迴歸的擴充套件。在談及具體的迴歸方法之前,讓我們先來看看線性迴歸為什麼叫回歸。

 

在19世紀的英國,有一位著名的生物學家高爾頓,在研究父母和孩子身高的遺傳關係時,發現了一個直線方程,通過這個方程,他幾乎準確地擬合了被調查父母的平均身高 x 和 子女平均身高 y 之前的關係:

這在當時可是一件不得了的事情,那這個方程是什麼意思呢?它代表父母身高每增加1個單位, 其成年子女的平均身高只增加0.516個單位,反映了一種“衰退”效應(“迴歸”到正常人平均身高)。雖然之後的x與 y變數之間並不總是具有“衰退”(迴歸)關係,但是為了紀念高爾頓這位偉大的統計學家,“線性迴歸”這一名稱就保留了下來。

 

一元線性迴歸

迴歸最簡單的情形是一元線性迴歸,也即由彼此間存線上性關係的一個自變數和一個因變數組成,方程寫作Y=a+bX+ε(X是自變數,Y是因變數,ε是隨機誤差)我們先來看一個例子:

假設有如下資料點,橫軸代表一個產品的廣告費X,縱軸代表產品的銷售額Y。迴歸可以看做是用Y=a+bX+ε這條直線去擬合這些資料點,也就是儘量使這些資料點與直線的距離之和(也叫作平方誤差和)最小。

通過最小二乘法或梯度下降(後面會講這兩個方法),我們得到如下方程:

那麼,當一個新的廣告計劃出現,我們通過已知的X(廣告投入)代入方程,就可以儘可能準確的算出預期的產品銷售額了~當然,反過來,使用一個預期的產品銷售額,也可以反推出我們需要投入的廣告費用。

 

多元線性迴歸

當自變數大於1的時候,比如X=(x1, x2),我們稱它為多元線性迴歸,寫作

其中,y(x)也就是我們說的因變數,x為自變數,但是0次項和1次項的係數(a, b)被一個向量w所代替。這裡的w是一個簡單的矩陣線性乘法問題,對應了以下的向量

如果對於線性代數問題還有不理解的小夥伴,可以自行查閱《線性代數同濟版》

來看一個例子:給定⼀個有關房屋的資料集,其中每棟房屋的相關資料包括⾯積(平⽅⽶)、房齡(年)和價格(元)。假設我們想使⽤任意⼀棟房屋的⾯積(設x1)和房齡(設x2)來估算它的真實價格(設y)。那麼x1 和x2 即每棟房屋的特徵(feature),y 為標籤(label)或真實值(ground truth)。線上性迴歸模型中,房屋估計價格(設 ​)的表示式為

其中w1,w2 是權重(weight),通常用向量

來表示,b 是偏差(bias),也就是前面一元迴歸裡我們用到的b。這⾥的權重和偏差都叫做線性迴歸模型的引數(parameter)

 

 

線性迴歸的假設

線性迴歸作為被嚴謹證明過的數學方法,有7個必備的假設前提。理論上,必須滿足這7個嚴格的假設,我們才能確保線性迴歸學習到的公式/方程是統計意義成立的。但在實際使用時,我們只需要滿足前3個最重要的假設即可(其他的一般都預設成立)。以後在學習其他演算法前,我們也需要了解類似的假設。

關於線性迴歸,最主要的3條假設如下

  • 隨機誤差的均值為0
  • 隨機誤差的方差為σ^2
  • σ^2與X的值無關

這裡規定的σ^2並不是一個具體數值,只需要大於0即可。主要是為了說明隨機誤差的方差是存在的,方差(二階中心矩)不存在,比如無限大,則這個模型就是病態的,這裡涉及到高等代數,感興趣的同學可以自行深入。

 

若進一步假定隨機誤差遵從正態分佈,就叫做正態線性模型。若有k個自變數和1個因變數,則因變數的值分為兩部分:一部分由自變數影響,即表示為它的函式,函式形式已知且含有未知引數;另一部分由其他的未考慮因素和隨機性影響,即隨機誤差。

一般來說,隨機誤差在引數學習中起到的作用有限,但我們在真正使用模型時還是需要先看看資料是否滿足線性迴歸的前提,否則容易對模型引數產生擬合問題。

 

線性迴歸的本質

一般來說,迴歸分析是通過規定因變數和自變數來確定變數之間的因果關係,建立迴歸模型,並根據實測資料來求解模型的各個引數,然後評價迴歸模型是否能夠很好的擬合實測資料,如果能夠很好的擬合,則可以根據自變數作進一步預測,比如我們提到的廣告費用與產品銷售額的關係。

當函式為引數未知的線性函式時,稱為線性迴歸分析模型;當函式為引數未知的非線性函式時,稱為非線性迴歸分析模型。當自變數個數大於1時稱為多元迴歸,當因變數個數大於1時稱為多元迴歸。

當X和Y只有一個維度(一元迴歸),且因變數和自變數的關係是線性關係,線性圖表示就是一條直線,而多維度(多元迴歸)學習到的引數方程,體現到空間中就是一個超平面。

 

要注意的是,我們平時可能會把擬合與迴歸弄混淆,但其實二者有本質區別。你可以把現實世界的資料看做“表象”,把你擬合出來的那個模型看做“本質”。由表象到本質的過程就是“迴歸”。而擬合是一種得到函式的手段,常和數值領域的“插值”放在一起,也就是得到迴歸函式的手段。

當迴歸函式未知時,我們可以通過擬合這種手段算出迴歸函式,求這個迴歸函式的問題叫做迴歸問題。一個是問題的類別,一個是解決方法的類別,迴歸和擬合的差別就在這裡。

數學理論的世界是精確的,譬如在廣告-銷量方程中,你代入x=0就能得到唯一的 y=7.1884,但這個y並不是我們真實觀測到的,而是估計值。現實世界中的資料就像散點圖,我們只能儘可能地在雜亂中尋找規律,很難100%的完美擬合一條直線出來。用數學的模型去擬合現實的資料,這就是統計。統計不像數學那麼精確,統計的世界不是非黑即白的,它有“灰色地帶”,但是統計會將理論與實際間的差別表示出來,也就是“誤差”。

我們在前面學習到的公式,就是線性迴歸作為一種學習演算法的本質,即模型

通過求解引數w,我們知道了自變數和因變數之間的線性關係,即我們擬合的直線

這個直線就是我們學習到的模型,儘可能地學習到一個完美的W,這就是線性迴歸的本質和作用,也是一切機器學習學習演算法的本質——函式(引數)學習。

 

 

線性迴歸的引數學習

現在我們來看一看如何學習到這個完美的W。線性迴歸的目標可以理解為減少殘差平方和,回到總體均值。在探究線性迴歸的學習方法之前,我們先定義如下表示,方便後續解釋:

  ​表示輸入變數(自變數),第一部分例子中的X。

  ​表示輸出變數(因變數),第一部分例子中的Y。

一對 ​表示一組訓練樣本。

m個訓練樣本 ​稱為訓練集。

 

回到上面一元迴歸的例子來,既然是用直線擬合散點,為什麼最終得到的直線是y = 0.0512x + 7.1884,而不是下圖中的橙色的y = 0.0624x + 5呢?畢竟這兩條線看起來都可以擬合這些資料。

我們很容發現,資料不是真的落在一條直線上,而是分佈在直線周圍,所以我們要找到一個評判標準,用於評價哪條直線才是最“合適”的。這就是我們以後時常會見到的東西——損失函式

 

在這裡,我們使用的損失函式叫做殘差,也就是真實值和預測值間的差值(也可以理解為距離),用公式表示是:

對於某個廣告投入 ​ ,我們有對應的實際銷售量 ​和預測出來的銷售量 ​(通過將 ​代入直線公式計算得到),計算  ​ 的值,再將其平方(為了消除負號),將所有的  ​相加,就能量化出擬合的直線和實際之間的誤差。

這裡使用的均方誤差有非常好的幾何意義,它對應了常用的歐幾里得距離或簡稱"歐氏距離" (Euclidean distance),也就是圖裡的資料點和直線之間的距離。基於均方誤差最小化來進行模型求解的方法,稱為“最小二乘法” (least square method)。線上性迴歸中,最小二乘法就是試圖找到一條直線,使所有樣本到直線上的歐氏距離之和最小。

 

最小二乘法

求解方程引數,使

最小化的過程,稱為線性迴歸模型的最小二乘"引數估計" (parameter estimation)。我們可將E(w,b)分別對W 和b求導,得到

我們令上面兩個式子的導數為零可得到W和b最優解的閉式(closed-form) 解(也就是可以直接通過公式代入算出來的解)

其中 ​,為X的均值

 

當X和Y為多元迴歸時,我們也有多元情況下的最小二乘法,寫作

同樣地,另這個偏導數為0,我們可以得到

這個東西也叫作正規方程,因為它很正規 (。≖ˇェˇ≖。)。具體的推導和運算涉及到矩陣的逆/偽逆,比單變數情形要複雜,這裡不深入展開,感興趣的同學可以自行翻閱《矩陣論》。這裡我們只需要知道兩點:

1. 現實任務中XTX 往往不是滿秩矩陣.例如在許多工中我們會遇到大量的變數,其數目甚至超過樣例數,導致X 的列數多於行數,XTX 顯然不滿秩。此時可得到出無限個解, 它們都能使均方誤差最小化

2. 我們的輸入資料X(自變數),可以寫作矩陣形式,矩陣的橫軸代表每個資料的維度(比如房屋的價格,位置,年齡),縱軸代表每個房屋

學過線性代數的同學應該知道,當行、列很多時, 這個矩陣的任何運算都需要很大的計算量。尤其是輸入變數的維度較大時(橫軸的n比較大),該演算法的計算複雜度成指數級增加。

因此,正規方程的解法在真實場景中很少見,我們有另一種叫做梯度下降的方法,通過損失一定精度,來近似逼近這個最優解。對於梯度方法,這裡只做一個簡單介紹,之後會有專門的一期文章來聊聊機器學習中的優化方法。

 

梯度下降

這裡引用優化教材中的一張圖,這裡我們把引數W寫作 ​,兩者其實是一種東西。現在我們來思考:既然代價函式是關於 ​的函式,有沒有辦法把求解過程加速或者拆解呢?

答案是有。

上圖中的藍色區域可以理解為誤差函式最小的點,也就是我們要找的引數值,因此,找到該點對應的 ​,即完成了任務。如何找到最低點位置對應的引數呢?答案是對代價函式(也就是我們的誤差)求偏導數

我們用大學學過的微積分方法做一個拆解,可以得到:

這就是關於變數的偏導數。要注意的是,這裡的h其實就是我們的y

假設我們的函式只有兩個維度(二元迴歸),給定

就是我們要求的引數,誤差函式對第一個元求偏導的結果:

誤差函式對第二個元求偏導的結果:

求得的結果怎麼使用?我們對  ​ 求偏導數的意義是得到這一點上的切線的斜率,它將給我們一個向最小值移動的方向。因此, ​減去偏導數,就等於 ​向最小值的方向移動了一步。這一步的大小由一個引數決定,也稱作學習率。用公式表達如下:

這就是機器學習中大名鼎鼎的的梯度下降。這個公式為什麼這麼寫,有什麼意義,之後在優化方法的文章中會寫。對底層原理感興趣的話,大家可以以前去看看MIT的微積分公開課(可汗學院、網易都有),以及Boyd所寫的《凸優化》,到時候看部落格就會非常通透。

這裡舉一個我在知乎上看到的例子,非常具體的解釋了線性迴歸求解的過程

  1. 初始化一個模型,例如 h = 2 + 3x,也就是說,我們的初始引數是 
  2. 給定一個樣本對,例如(2,4),代入模型中求得預測值,即 h = 2 + 3*2 = 8
  3. 代入代價函式公式中,求代價值,即 J = 1/2 * (8-4) ^ 2 = 8
  4. 代入偏導數公式中求兩個變數的偏導數,即 

假設我們的學習率是0.1,那麼代入梯度下降公式得到 

我們得到了新的引數,即 

所以新的模型是:h = 1.6 + 2.2x,新的預測值是h = 1.6 + 2.2*2 = 6,再次計算代價函式的值:J = 1/2 * (6-4) ^ 2 = 2

比較新的模型得到的代價值2,比老模型得到的代價值8減少了6,代價越小說明我們的模型與訓練集匹配的越好,所以通過不斷的梯度下降,我們可以得到最適合訓練資料的模型h,也就是前面提到的那條直線方程。

 

線性迴歸的侷限

線性迴歸簡單、直觀、迅速,但也有不少侷限,這也是之後更多高階演算法的出現原因,它們一定程度上解決了線性迴歸無法解決的問題。線性迴歸的侷限可以歸納以下幾點:

  • 需要嚴格的假設。
  • 只能用於變數間存在簡單線性關係的資料。
  • 當資料量、資料維度大時,計算量會指數級增加。
  • 需處理異常值,對異常值很敏感,對輸入資料差異也很敏感。
  • 線性迴歸存在共線性,自相關,異方差等問題。

 

結語

到這裡,線性迴歸的文章就告一段落了。在這一篇文章中,我們通過線性迴歸,簡單瞭解了機器學習的方式、概念和方法,但是對於更加具體的定義,比如模型、損失函式和監督學習還沒講到,這將是我之後文章的主題。下一期文章,我們將基於線性迴歸,來深入探討回歸的更多使用方法。

線性迴歸的Python程式碼和案例實戰在這一篇:「06」迴歸的誘惑:一文讀懂線性迴歸(Python實戰篇) ,程式碼不多,建議大家可以自己敲一敲。

PS. 如果大家閱讀其中的數學部分有些吃力的話,可以到我的這篇文章中找對應的知識點複習:「04」機器學習、深度學習需要哪些數學知識?

 

 

課後習題

給定每月電話諮詢次數(X)和每月實際銷量,線性迴歸是否可以把圖中的資料點分為不同的兩個部分?如果可以,應該怎麼分?如果不可以,又是為什麼?

參考文獻

  1. 《機器學習》周志華
  2. 《動手學深度學習》MXNet Community
  3.  An Introduction toStatistical Learning with Applications in R
  4.  知乎:機器學習之線性迴歸
  5.  知乎:線性迴歸詳解