1. 程式人生 > >多項式迴歸案例(附資料集下載地址)

多項式迴歸案例(附資料集下載地址)

當我們完成了資料的預處理環節後,我們可以先對資料進行視覺化,根據影象可以初步的判斷我們的模型應該是怎麼樣的,如何更好地擬合,請看下面的例子:
資料集:

Position Level Salary
Business Analyst 1 45000
Junior Consultant 2 50000
Senior Consultant 3 60000
Manager 4 80000
Country Manager 5 110000
Region Manager 6 150000
Partner 7 200000
Senior Partner 8 300000
C-level 9 500000
CEO 10 1000000
#首先還是匯入必要的庫
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
import statsmodels.formula.api as sm
#載入資料,由於資料樣本少,不做分片
dataset = pd.read_csv('Position_Salaries.csv')
# X = dataset.iloc[:,1].values   #shape -- (10,)是個向量,可是我們需要進入訓練的變數需要是一個矩陣
X = dataset.iloc[:,1:2].values #shape -- (10, 1) 矩陣 Y = dataset.iloc[:,2].values #shape -- (10,)是個向量
#視覺化資料
plt.figure(figsize=(10,14))
plt.subplot(211)
plt.scatter(X,Y)
plt.savefig('scatter.png')

資料描點
資料描點完成了,根據影象容易得出一元方程擬合效果不會太好。
一元線性迴歸
實際也正如我們料想的一致

觀察圖一,我們可以選用多項式來解決這個迴歸問題

  1. 對自變數進行矩陣轉化,轉化為有不同次數的矩陣
# 對X進行多次項處理
Poly = PolynomialFeatures(degree=2)  #引數degree是限定生成的X矩陣的最高次數
X_poly = Poly.fit_transform(X)

輸出X_poly結果如下:

#這個操作自動添加了常數項的係數(第一列)第二列是一次項,第二列是二次項
[[  1.   1.   1.]   
 [  1.   2.   4.]
 [  1.   3.   9.]
 [  1.   4.  16.]
 [  1.   5.  25.]
 [  1.   6.  36.]
 [  1.   7.  49.]
 [  1.   8.  64.]
 [  1.   9.  81.]
 [  1.  10. 100.]]
poly_reg = sm.OLS(endog=Y,exog=X_poly).fit()
Y_pre2 = poly_reg.predict(X_poly)
plt.plot(X,poly_reg.predict(X_poly) ,color = 'black',label = 'poly_degree=2')
plt.legend()
plt.savefig('lin&poly.png')

這裡寫圖片描述
顯然擬合也不是特別好。可以通過提高多項式次數來達到更好的擬合度,小心過度擬合問題
Poly = PolynomialFeatures(degree=3)  #引數degree是限定生成的X矩陣的最高次數
X_poly = Poly.fit_transform(X)
poly_reg = sm.OLS(endog=Y,exog=X_poly).fit()
Y_pre2 = poly_reg.predict(X_poly)
plt.plot(X,poly_reg.predict(X_poly) ,color = 'green',label = 'poly_degree=3')
plt.legend()  #提高到3次基本擬合效果很好了
plt.savefig('lin&poly3.png')

這裡寫圖片描述
提高到3次基本擬合效果很好了

最後對影象進行優化處理,上述影象由於自變數的間距相對較大,影象不夠平滑。我們可以有如下操作:

X_grid = np.arange(min(X),max(X),0.1)
X_grid = X_grid.reshape(len(X_grid),1)

plt.plot(X_grid, poly_reg.predict(Poly.fit_transform(X_grid)) ,color = 'green',label = 'poly_degree=3')
plt.legend()

這裡寫圖片描述

迴歸器資訊:


迴歸器資訊