1. 程式人生 > >Python scikit-learn機器學習工具包學習筆記:cross_validation模組

Python scikit-learn機器學習工具包學習筆記:cross_validation模組

sklearn.cross_validation模組的作用顧名思義就是做cross validation的。

cross validation大概的意思是:對於原始資料我們要將其一部分分為train data,一部分分為test data。train data用於訓練,test data用於測試準確率。在test data上測試的結果叫做validation error。將一個演算法作用於一個原始資料,我們不可能只做出隨機的劃分一次train和test data,然後得到一個validation error,就作為衡量這個演算法好壞的標準。因為這樣存在偶然性。我們必須好多次的隨機的劃分train data和test data,分別在其上面算出各自的validation error。這樣就有一組validation error,根據這一組validation error,就可以較好的準確的衡量演算法的好壞。

cross validation是在資料量有限的情況下的非常好的一個evaluate performance的方法。

而對原始資料劃分出train data和test data的方法有很多種,這也就造成了cross validation的方法有很多種。

sklearn中的cross validation模組,最主要的函式是如下函式:

sklearn.cross_validation.cross_val_score。他的呼叫形式是scores = cross_validation.cross_val_score(clf, raw data, raw target, cv=5, score_func=None)

引數解釋:

clf是不同的分類器,可以是任何的分類器。比如支援向量機分類器。clf =svm.SVC(kernel='linear', C=1)

cv引數就是代表不同的cross validation的方法了。如果cv是一個int數字的話,並且如果提供了raw target引數,那麼就代表使用StratifiedKFold分類方式,如果沒有提供raw target引數,那麼就代表使用KFold分類方式。

cross_val_score函式的返回值就是對於每次不同的的劃分raw data時,在test data上得到的分類的準確率。至於準確率的演算法可以通過score_func引數指定,如果不指定的話,是用clf預設自帶的準確率演算法。

還有其他的一些引數不是很重要。

cross_val_score具體使用例子見下:

>>> clf = svm.SVC(kernel='linear', C=1)

>>> scores = cross_validation.cross_val_score(

...    clf, raw data, raw target, cv=5)

...

>>> scores                                            

array([ 1.  ...,  0.96...,  0.9 ...,  0.96...,  1.        ])

除了剛剛提到的KFold以及StratifiedKFold這兩種對raw data進行劃分的方法之外,還有其他很多種劃分方法。但是其他的劃分方法呼叫起來和前兩個稍有不同(但是都是一樣的),下面以ShuffleSplit方法為例說明:

>>> n_samples = raw_data.shape[0]

>>> cv = cross_validation.ShuffleSplit(n_samples, n_iter=3,

...     test_size=0.3, random_state=0)



>>> cross_validation.cross_val_score(clf, raw data, raw target, cv=cv)

...                                                     

array([ 0.97...,  0.97...,  1.        ])

還有的其他劃分方法如下:

cross_validation.Bootstrap

cross_validation.LeaveOneLabelOut

cross_validation.LeaveOneOut

cross_validation.LeavePLabelOut

cross_validation.LeavePOut

cross_validation.StratifiedShuffleSplit

他們的呼叫方法和ShuffleSplit是一樣的,但是各自有各自的引數。至於這些方法具體的意義,見machine learning教材。

還有一個比較有用的函式是train_test_split

功能:從樣本中隨機的按比例選取train data和test data。呼叫形式為:

X_train, X_test, y_train, y_test = cross_validation.train_test_split(train_data, train_target, test_size=0.4, random_state=0)

test_size是樣本佔比。如果是整數的話就是樣本的數量。random_state是隨機數的種子。不同的種子會造成不同的隨機取樣結果。相同的種子取樣結果相同。