1. 程式人生 > >機器學習十大演算法--迴歸演算法(批量梯度下降)

機器學習十大演算法--迴歸演算法(批量梯度下降)

機器學習演算法分為有監督學習和無監督學習,迴歸演算法屬於監督學習一類。本人小白一個,想學習機器學習,以後有新的學習結果,希望都可以寫在部落格上面,以此來監督自己的學習,如有不對的地方,還希望大家多多諒解,指出錯誤。

迴歸演算法主要是對已給的資料通過調整引數的手段來使計算結果接近於標籤值,但是資料點並非完全都恰好在一條直線上,因此需要有一個來評定調整後的引數的標準,一般使用類似方差的損失函式來評定調整後的引數的好壞,損失函式值越小,說明擬合度越好,越大的話則相反。

評定調整後引數的好壞: 

假如該批資料有3個變數x1、x2、x3,一個標籤值y,符合一次方程。

對應的引數值是a1、a2、a3

列出函式式 y = a1*x1+a2*x2+a3*x3;

一般我們需要新增一個偏移引數a0來使函式式可以上下移動,可以更好的使函式接近真實情況,則 y =a0+ a1*x1+a2*x2+a3*x3;。

現在如何去評定引數值 a1、a2、a3的好壞,這時就需要去判斷這些資料點在進行a0+ a1*x1+a2*x2+a3*x3運算後的y與標籤Y的差值,此時就形成了[(y0,Y0),(y1,Y1),(y2,Y2)......(yn,Yn),現在構造損失函式 loss = (y0-Y0)^2+(y1-Y1)^2+(y2-Y2)^2+......+(yn-Yn)^2;


前面說到,損失函式的結果值越小越好,那麼現在就想辦法使損失函式的值最小。

如何調整引數:批量梯度下降是其中的一種方法,該方法類似高中函式中的知識點,即切點方向是函式上升的最快方向,反方向是函式下降最快的方向。現在loss函式對引數a0、a1、a2、a3分別求導,找出他們使loss下降最快的方向。


方向找到了,現在討論如何向這個方向移動,一次移動多少,如何調整引數。

現在又有一個叫做步長(學習率)的東西,步長太大會越過最優值,步長太小到達最優值又需要太多的步數。

類似如圖的形式:

步長一般需要根據經驗來設定,一般設定為0.01;

調整引數的公式為ai = ai -  ai*步長,以這種形式分別對a0,a1,a2,a3...an進行調整。

可以設定一個標準使兩次調整後的損失在一定範圍內時停止調整。

python程式碼實現:

#y = 1*x1+2*x2+3*x3+4*x4;
xi = [(1,2,3,4),(1,2,4,5),(1,1,3,4),(1,2,2,2),(1,2,7
,5),(1,7,8,8),(1,22,42,51),(1,21,41,25)]; yi = [30,37,28,19,46,71,375,266]; epsilon = 0.000000001 alpha = 0.00001 a0 = 0 a1 = 0 a2 = 0 a3 = 0 error0 = 0; error1 = 0; # y = a0*xi[0]+ a1*xi[1]+ a2*xi[2]+ a3*xi[3]; ctn=0 # dist = (y - (a0*xi[0]+ a1*xi[1]+ a2*xi[2]+ a3*xi[3]))**2 while True: ctn += 1; for i in range(xi.__len__()): deff = (a0*xi[i][0]+ a1*xi[i][1]+a2*xi[i][2]+ a3*xi[i][3])-yi[i] a0 -= alpha * deff * xi[i][0]; a1 -= alpha * deff * xi[i][1]; a2 -= alpha * deff * xi[i][2]; a3 -= alpha * deff * xi[i][3]; error0 = 0; for i in range(xi.__len__()): error0 += ((a0*xi[i][0]+ a1*xi[i][1]+ a2*xi[i][2]+ a3*xi[i][3])-yi[i])**2; if abs(error1 - error0) < epsilon : break; else : error1 = error0; print "ct" print "ctn: %f, a0: %f,a1: %f, a2:%f a3: %f error0: %f"%(ctn,a0,a1,a2,a3,error0)


結果:

ctn: 143878.000000, a0: 0.998394,a1: 2.000357, a2:2.999860 a3: 4.000003 error0: 0.000012
ct
ctn: 143879.000000, a0: 0.998394,a1: 2.000357, a2:2.999860 a3: 4.000003 error0: 0.000012
ct
ctn: 143880.000000, a0: 0.998394,a1: 2.000357, a2:2.999860 a3: 4.000003 error0: 0.000012

Process finished with exit code 0

a0、a1、a2、a3最後接近下面函式中的引數值。

y = 1*x1+2*x2+3*x3+4*x4;