1. 程式人生 > >防止過擬合的方法 預測鸞鳳花(sklearn)

防止過擬合的方法 預測鸞鳳花(sklearn)

1. 防止過擬合的方法有哪些?

過擬合(overfitting)是指在模型引數擬合過程中的問題,由於訓練資料包含抽樣誤差,訓練時,複雜的模型將抽樣誤差也考慮在內,將抽樣誤差也進行了很好的擬合。

產生過擬合問題的原因大體有兩個:訓練樣本太少或者模型太複雜。 

防止過擬合問題的方法:

(1)增加訓練資料。

考慮增加訓練樣本的數量

使用資料集估計資料分佈引數,使用估計分佈引數生成訓練樣本

使用資料增強

(2)減小模型的複雜度。

a.減少網路的層數或者神經元數量。這個很好理解,介紹網路的層數或者神經元的數量會使模型的擬合能力降低。
b.引數範數懲罰。引數範數懲罰通常採用L1和L2引數正則化(關於L1和L2的區別聯絡請戳這裡)。
c.提前終止(Early stopping);
d.新增噪聲。新增噪聲可以在輸入、權值,網路相應中新增。
e.結合多種模型。這種方法中使用不同的模型擬合不同的資料集,例如使用 Bagging,Boosting,Dropout、貝葉斯方法

 

而在深度學習中,通常解決的方法如下

Early stopping方法的具體做法是,在每一個Epoch結束時(一個Epoch集為對所有的訓練資料的一輪遍歷)計算validation data的accuracy,當accuracy不再提高時,就停止訓練。

獲取更多資料(從資料來源頭獲取更多資料      根據當前資料集估計資料分佈引數,使用該分佈產生更多資料    資料增強(Data Augmentation)

正則化(直接將權值的大小加入到 Cost 裡,在訓練的時候限制權值變大)

dropout:在訓練時,每次隨機(如50%概率)忽略隱層的某些節點;

 

 

2. 使用邏輯迴歸(Logistic Regression)對鳶尾花資料(多分類問題)進行預測,可以直接使用sklearn中的LR方法,並嘗試使用不同的引數,包括正則化的方法,正則項係數,求解優化器,以及將二分類模型轉化為多分類模型的方法。
獲取鳶尾花資料的方法:
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)

 

print(__doc__)


# Code source: Gaël Varoquaux
# Modified for documentation by Jaques Grobler
# License: BSD 3 clause import numpy as np import matplotlib.pyplot as plt from sklearn import linear_model, datasets # import some data to play with iris = datasets.load_iris() X = iris.data[:, :2] # we only take the first two features. Y = iris.target h = .02 # step size in the mesh logreg = linear_model.LogisticRegression(C=1e5) # we create an instance of Neighbours Classifier and fit the data. logreg.fit(X, Y) # Plot the decision boundary. For that, we will assign a color to each # point in the mesh [x_min, x_max]x[y_min, y_max]. x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5 y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) Z = logreg.predict(np.c_[xx.ravel(), yy.ravel()]) # Put the result into a color plot Z = Z.reshape(xx.shape) plt.figure(1, figsize=(4, 3)) plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired) # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=Y, edgecolors='k', cmap=plt.cm.Paired) plt.xlabel('Sepal length') plt.ylabel('Sepal width') plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max()) plt.xticks(()) plt.yticks(()) plt.show()