1. 程式人生 > >防止過擬合的方法 預測鸞鳳花(sklearn)

防止過擬合的方法 預測鸞鳳花(sklearn)

ogr mod sep 類模型 for 包含 一輪 com stop

1. 防止過擬合的方法有哪些?

過擬合(overfitting)是指在模型參數擬合過程中的問題,由於訓練數據包含抽樣誤差,訓練時,復雜的模型將抽樣誤差也考慮在內,將抽樣誤差也進行了很好的擬合。

產生過擬合問題的原因大體有兩個:訓練樣本太少或者模型太復雜。

防止過擬合問題的方法:

(1)增加訓練數據。

考慮增加訓練樣本的數量

使用數據集估計數據分布參數,使用估計分布參數生成訓練樣本

使用數據增強

(2)減小模型的復雜度。

a.減少網絡的層數或者神經元數量。這個很好理解,介紹網絡的層數或者神經元的數量會使模型的擬合能力降低。
b.參數範數懲罰。參數範數懲罰通常采用L1和L2參數正則化(關於L1和L2的區別聯系請戳這裏)。
c.提前終止(Early stopping);
d.添加噪聲。添加噪聲可以在輸入、權值,網絡相應中添加。
e.結合多種模型。這種方法中使用不同的模型擬合不同的數據集,例如使用 Bagging,Boosting,Dropout、貝葉斯方法

技術分享圖片

而在深度學習中,通常解決的方法如下

Early stopping方法的具體做法是,在每一個Epoch結束時(一個Epoch集為對所有的訓練數據的一輪遍歷)計算validation data的accuracy,當accuracy不再提高時,就停止訓練。

獲取更多數據(從數據源頭獲取更多數據 根據當前數據集估計數據分布參數,使用該分布產生更多數據 數據增強(Data Augmentation)

正則化(直接將權值的大小加入到 Cost 裏,在訓練的時候限制權值變大)

dropout:在訓練時,每次隨機(如50%概率)忽略隱層的某些節點;

2. 使用邏輯回歸(Logistic Regression)對鳶尾花數據(多分類問題)進行預測,可以直接使用sklearn中的LR方法,並嘗試使用不同的參數,包括正則化的方法,正則項系數,求解優化器,以及將二分類模型轉化為多分類模型的方法。
獲取鳶尾花數據的方法:
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)

print(__doc__)


# Code source: Ga?l Varoquaux
# Modified for documentation by Jaques Grobler
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, datasets

# import some data to play with
iris = datasets.load_iris()
X = iris.data[:, :2]  #
we only take the first two features. Y = iris.target h = .02 # step size in the mesh logreg = linear_model.LogisticRegression(C=1e5) # we create an instance of Neighbours Classifier and fit the data. logreg.fit(X, Y) # Plot the decision boundary. For that, we will assign a color to each # point in the mesh [x_min, x_max]x[y_min, y_max]. x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5 y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) Z = logreg.predict(np.c_[xx.ravel(), yy.ravel()]) # Put the result into a color plot Z = Z.reshape(xx.shape) plt.figure(1, figsize=(4, 3)) plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired) # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=Y, edgecolors=k, cmap=plt.cm.Paired) plt.xlabel(Sepal length) plt.ylabel(Sepal width) plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max()) plt.xticks(()) plt.yticks(()) plt.show()

技術分享圖片

防止過擬合的方法 預測鸞鳳花(sklearn)