17.【進階】模型正則化--欠擬合與過擬合問題
阿新 • • 發佈:2019-01-02
#-*- coding:utf-8 -*-
#學習目標:以“披薩餅價格預測”為例,認識欠擬合和過擬合的問題
#假定只考慮披薩的尺寸和售價的關係,X為尺寸,y代表售價
X_train = [[6],[8],[10],[14],[18]]
y_train = [[7],[9],[13],[17.5],[18]]
#*************************************************************************************
#1.首先以一次線性迴歸函式進行預測
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
lr.fit(X_train,y_train)
import numpy as np
#在X軸上從0~25均勻取樣100個數據點
xx = np.linspace(0,26,100)
xx = xx.reshape(xx.shape[0],1)
#以上述100個數據點為基準,預測迴歸直線
yy = lr.predict(xx)
#對迴歸預測到的直線進行作圖
import matplotlib.pyplot as plt
#scatter()功能:繪製散點圖,c='r'表示點的顏色為red,marker表示點的形狀'o'是預設的圓
# b---blue c---cyan g---green k----black m---magenta r---red w---white y---yellow
plt.scatter(X_train,y_train,c='r' ,marker='o')
#plot()功能:繪製折線圖,同樣可以設定顏色樣式,color屬性和linestyle屬性
plt1,= plt.plot(xx,yy)
#確定座標範圍:plt.axis([xmin, xmax, ymin, ymax])
#xlim(xmin, xmax)和ylim(ymin, ymax)來調整x,y座標範圍
plt.axis([0,25,0,25])
plt.xlabel('Diameter of Pizza')
plt.ylabel('Price of Pizza')
#注意legend的引數形式,第二個引數,字串放在()裡面,並且當只有一個引數時,要在結尾加上‘,’
plt.legend([plt1],('Degree=1',),'best')
plt.show()
#輸出線性迴歸模型在<訓練樣本>上的r2-score
print 'The R-squared value of lr is ',lr.score(X_train,y_train)
#The R-squared value of lr is 0.910001596424
#*************************************************************************************
#2.接下來以二次多項式迴歸函式進行預測
#將原特徵升高一個維度,以二次多項式迴歸模型對訓練樣本進行擬合
from sklearn.preprocessing import PolynomialFeatures
poly2 = PolynomialFeatures(degree=2) #映射出二次多項式特徵
X_train_poly2 = poly2.fit_transform(X_train)
lr_poly2 = LinearRegression()
lr_poly2.fit(X_train_poly2,y_train)
#重新映射回歸值,並繪製圖像
#因為訓練的模型lr_poly2針對的是2D的資料,所以此處要將xx也轉成2D的形式,才能預測對應的y值
xx_poly2 = poly2.transform(xx)
yy_poly2 = lr_poly2.predict(xx_poly2)
plt.scatter(X_train,y_train,c='r')
#獲取plot物件的方法:
#line, = plt.plot(x, y, '-')
#這裡的','不可以省略,不然在下面的legend中就會出錯。
plt1,= plt.plot(xx,yy)
#這邊繪圖時,傳的引數就是(xx,yy_poly2)了,不能是xx_poly2,點的座標形式是(x,y),x應為1維向量
plt2, = plt.plot(xx,yy_poly2)
plt.axis([0,25,0,25])
plt.xlabel('Diameter of Pizza')
plt.ylabel('Price of Pizza')
#legend:新增圖例(對直線的描述),第一個引數是要顯示的直線的列表[],第二個引數是每條直線的label,第三個引數是顯示的位置
plt.legend([plt1,plt2],('degree=1','degree=2'),'best')
plt.show()
#輸出二次多項式迴歸模型在<訓練樣本>上的r2-score
print 'The R-squared value of poly2 is ',lr_poly2.score(X_train_poly2,y_train)
#The R-squared value of poly2 is 0.98164216396
#*************************************************************************************
#3.最後再以四次多項式迴歸函式進行預測
poly4 = PolynomialFeatures(degree=4)
X_train_poly4 = poly4.fit_transform(X_train)
lr_poly4 = LinearRegression()
lr_poly4.fit(X_train_poly4,y_train)
#重新預測xx_poly4對應的迴歸值,並繪圖
xx_poly4 = poly4.transform(xx)
yy_poly4 = lr_poly4.predict(xx_poly4)
plt.scatter(X_train,y_train,c='r')
plt1, = plt.plot(xx,yy)
plt2, = plt.plot(xx,yy_poly2)
plt3, = plt.plot(xx,yy_poly4)
plt.axis([0,25,0,25])
plt.xlabel('Diameter of Pizza')
plt.ylabel('Price of Pizza')
plt.legend([plt1,plt2,plt3],('degree=1','degree=2','degree=4'),'best')
plt.show()
#輸出四次多項式迴歸模型在<訓練樣本>上的r2-score
print 'The R-squared value of poly4 is ',lr_poly4.score(X_train_poly4,y_train)
#The R-squared value of poly4 is 1.0
#總結:
#在實際生活中,第二個模型是最滿足真實情況的
#第一個模型的複雜度太低,導致了欠擬合
#第三個模型的複雜度太高,導致了過擬合
#第二個模型,相對來說,泛化能力更好一些
#為了兼顧模型的複雜度和預測準確性,我們採用了模型正則化方法,在下一講進行說明。