機器學習吳恩達-線性回歸筆記(1)
阿新 • • 發佈:2019-04-06
設置 裏的 更新 sha names value p s itl inf
回歸問題的思想(1)先找到損失函數,(2)求損失函數最小化後的參數
假設我們的數據是(m,n)有m行數據,n個特征(feature)
則我們預測函數為 :
寫成向量形式為(xo=1):
ps:因為存在截距項,這裏的X矩陣是n+1維的
定義代價函數CostFunction:
求 minJ(θ)
得到目標函數後,我們目標是想要代價函數盡可能小,利用凸優化知識,對J(θ)求偏導並帶入梯度下降公式中:
梯度下降請參考:https://www.cnblogs.com/pinard/p/5970503.html
向量式:i=(1,2,...n)
α是步長,決定更新快慢(過大可能會導致溢出)
到這裏就能求出所需要的參數的更新公式。
下面的例子是單變量的例子,用的隨機梯度,因為可以直接畫圖出,比較直觀。
1 import pandas as pd 2 import numpy as np 3 def CostFunction(x,y,theta): 4 med_var=np.power((x*theta)-y,2) 5 return np.sum(med_var)/(2*len(x)) 6 7 def Grandent(x,y,theta,alphl,maxcircle): 8 m = x.shape[0] 9 print(‘m=‘,m) 10 print(‘theta shape‘,theta.shape) 11 temp = np.matrix(np.zeros(theta.shape)) 12 print(‘temp shape‘,temp.shape) 13 cost = np.zeros(maxcircle) # 初始化一個ndarray,包含每次叠代的cost 14 for k in range(maxcircle): 15 # print(theta) 16 temp=theta-(alphl/m)*(x.T)*(x*theta-y)17 cost[k]=CostFunction(x,y,theta) 18 theta=temp 19 return theta,cost 20 21 22 data=pd.read_csv(‘ex1data1.txt‘,names=[‘feature‘,‘price‘]) #(97, 2) 23 data.insert(0,‘x0‘,1) 24 X_dataframe=data.drop([‘price‘],axis=1) 25 y_dataframe=data.price 26 X=np.matrix(X_dataframe.values) #專程矩陣格式 27 y=np.matrix(y_dataframe.values) 28 y=y.T 29 m,n=X_dataframe.shape 30 theta=np.zeros((n,1)) 31 alphl=0.01 #開始設置為0.1,會一直報溢出,導致梯度下降方法不收斂 32 maxcircle=1000 33 34 theta_fin,cost=Grandent(X,y,theta,alphl,maxcircle) 35 36 import matplotlib.pyplot as plt 37 38 fig,ax = plt.subplots() 39 40 ax.plot(np.arange(maxcircle), cost, ‘red‘) # np.arange()返回等差數組 41 ax.set_xlabel(‘Iterations‘) 42 ax.set_ylabel(‘Cost‘) 43 ax.set_title(‘Cost vs. num_Iterations‘) 44 45 #np.linspace()在指定的間隔內返回均勻間隔的數字 46 x = np.linspace(data.feature.min(), data.feature.max(), 100) # 橫坐標 47 f = theta_fin[0,0] + (theta_fin[1,0] * x) # 縱坐標 48 fig,ax = plt.subplots() 49 ax.plot(x, f, ‘r‘, label=‘Prediction‘) 50 ax.scatter(data[‘feature‘], data.price, label=‘Traning Data‘) 51 ax.legend(loc=2) # 2表示在左上角 52 ax.set_xlabel(‘feature‘) 53 ax.set_ylabel(‘price‘) 54 ax.set_title(‘Predicted Profit vs. Population Size‘) 55 plt.show()
下面是叠代1000,2000次的結果
機器學習吳恩達-線性回歸筆記(1)