1. 程式人生 > >使用scikit-learn進行機器學習的簡介(教程1)

使用scikit-learn進行機器學習的簡介(教程1)

一、機器學習:問題設定

通常,一個學習問題是通過分析一些資料樣本來嘗試預測未知資料的屬性。如果每一個樣本不僅僅是一個單獨的數字,比如一個多維的例項(multivariate data),也就是說有著多個屬性特徵

我們可以把學習問題分成如下的幾個大類:

  • (1)有監督學習
    資料帶有我們要預測的屬性。這種問題主要有如下幾種:

    • ①分類
      樣例屬於兩類或多類,我們想要從已經帶有標籤的資料學習以預測未帶標籤的資料。識別手寫數字就是一個分類問題,這個問題的主要目標就是把每一個輸出指派到一個有限的類別中的一類。另一種思路去思考分類問題,其實分類問題是有監督學習中的離散形式問題。每一個都有一個有限的分類。對於樣例提供的多個標籤,我們要做的就是把未知類別的資料劃分到其中的一種。

    • ②迴歸
      去過預期的輸出包含連續的變數,那麼這樣的任務叫做迴歸。根據三文魚的年紀和中聯預測其長度就是一個迴歸樣例。

  • (2)無監督學習
    訓練資料包含不帶有目標值的輸入向量x。對於這些問題,目標就是根據資料發現樣本中相似的群組——聚類。或者在輸入空間中判定資料的分佈——密度估計,或者把資料從高維空間轉換到低維空間以用於視覺化

訓練集和測試集
機器學習是學習一些資料集的特徵屬性並將其應用於新的資料。這就是為什麼在機器學習用來評估演算法時一般把手中的資料分成兩部分。一部分我們稱之為訓練集,用以學習資料的特徵屬性。一部分我們稱之為測試集,用以檢驗學習到的特徵屬性。

二、載入一個樣本資料集

scikit-learn帶有一些標準資料集。比如用來分類的iris資料集、digits資料集;用來回歸的boston house price 資料集。

接下來,我們我們從shell開啟一個Python直譯器並載入iris和digits兩個資料集。【譯註:一些程式碼慣例就不寫了,提示符>>>之類的學過Python的都懂】

[python] view plain copyprint?
  1. <code class="hljs" style="margin:0px; padding:0px">$ python  
  2. >>><span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">from
     sklearn <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import datasets  
  3. >>>iris = datasets.load_iris()  
  4. >>>digits = datasets.load_digits()</span></span></code>  
$ python
>>>from sklearn import datasets
>>>iris = datasets.load_iris()
>>>digits = datasets.load_digits()

一個數據集是一個包含資料所有元資料的類字典物件。這個資料儲存在 '.data'成員變數中,是一個$n*n$的陣列,行表示樣例,列表示特徵。在有監督學習問題中,一個或多個響應變數(Y)儲存在‘.target’成員變數中。不同資料集的更多細節可以在dedicated section中找到。

例如,對於digits資料集,digits.data可以訪問得到用來對數字進行分類的特徵:

[python] view plain copyprint?
  1. <code class="hljs" style="margin:0px; padding:0px">>>>print(digits.data)    
  2. [[  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">5. ...,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]  
  3.  [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0. ...,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]  
  4.  [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0. ...,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">16.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">9.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]  
  5.  ...,  
  6.  [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1. ...,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">6.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]  
  7.  [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2. ...,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">12.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]  
  8.  [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10. ...,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">12.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]]</span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></code>  
>>>print(digits.data)  
[[  0.   0.   5. ...,   0.   0.   0.]
 [  0.   0.   0. ...,  10.   0.   0.]
 [  0.   0.   0. ...,  16.   9.   0.]
 ...,
 [  0.   0.   1. ...,   6.   0.   0.]
 [  0.   0.   2. ...,  12.   0.   0.]
 [  0.   0.  10. ...,  12.   1.   0.]]

digits.target 就是數字資料集對應的真實數字值。也就是我們的程式要學習的。

[python] view plain copyprint?
  1. <code class="hljs" style="margin:0px; padding:0px">>>>digits.target  
  2. array([<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0, <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1, <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2, ..., <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8, <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">9, <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8])</span></span></span></span></span></span></code>  
>>>digits.target
array([0, 1, 2, ..., 8, 9, 8])

資料陣列的形狀
儘管原始資料也許有不同的形狀,但實際使用的資料通常是一個二維陣列(n個樣例,n個特徵)。對於數字資料集,每一個原始的樣例是一張(8 x 8)的圖片,也能被使用:

[python] view plain copyprint?
  1. <code class="hljs" style="margin:0px; padding:0px">>>>digits.images[<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0]  
  2. array([[  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">5.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">13.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">9.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
  3.        [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">13.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">15.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">15.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">5.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
  4.        [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">3.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">15.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">11.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
  5.        [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">4.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">12.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
  6.        [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">5.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">9.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
  7.        [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">4.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">11.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">12.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">7.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
  8.        [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">14.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">5.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">12.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
  9.        [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">6.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">13.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]])</span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></code>  
>>>digits.images[0]
array([[  0.,   0.,   5.,  13.,   9.,   1.,   0.,   0.],
       [  0.,   0.,  13.,  15.,  10.,  15.,   5.,   0.],
       [  0.,   3.,  15.,   2.,   0.,  11.,   8.,   0.],
       [  0.,   4.,  12.,   0.,   0.,   8.,   8.,   0.],
       [  0.,   5.,   8.,   0.,   0.,   9.,   8.,   0.],
       [  0.,   4.,  11.,   0.,   1.,  12.,   7.,   0.],
       [  0.,   2.,  14.,   5.,  10.,  12.,   0.,   0.],
       [  0.,   0.,   6.,  13.,  10.,   0.,   0.,   0.]])

三、學習和預測

對於數字資料集(digits dataset),任務是預測一張圖片中的數字是什麼。數字資料集提供了0-9每一個數字的可能樣例,可以用它們來對位置的數字圖片進行擬合分類。

在scikit-learn中,用以分類的擬合(評估)函式是一個Python物件,具體有fit(X,Y)和predic(T)兩種成員方法。

其中一個擬合(評估)樣例是sklearn.svmSVC類,它實現了支援向量分類(SVC)。一個擬合(評估)函式的建構函式需要模型的引數,但是時間問題,我們將會把這個擬合(評估)函式作為一個黑箱:

[python] view plain copyprint?
  1. <code class="hljs" style="margin:0px; padding:0px">>>><span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">from sklearn <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import svm  
  2. >>>clf = svm.SVC(gamma=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.001, C=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">100.)</span></span></span></span></code>  
>>>from sklearn import svm
>>>clf = svm.SVC(gamma=0.001, C=100.)

選擇模型引數
我們呼叫擬合(估測)例項clf作為我們的分類器。它現在必須要擬合模型,也就是說,他必須要學習模型。這可以通過把我們的訓練集傳遞給fit方法。作為訓練集,我們使用其中除最後一組的所有影象。我們可以通過Python的分片語法[:-1]來選取訓練集,這個操作將產生一個新陣列,這個陣列包含digits.dataz中除最後一組資料的所有例項。

[python] view plain copyprint?
  1. <code class="hljs" style="margin:0px; padding:0px">>>>clf.fit(digits.data[:-<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1], digits.target[:-<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1])    
  2. SVC(C=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">100.0, cache_size=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">200, class_weight=<span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">None, coef0=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.0, degree=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">3,  
  3. gamma=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.001, kernel=<span class="hljs-string" style="margin:0px; padding:0px; color:rgb(163,21,21); line-height:1.8">'rbf', max_iter=-<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1, probability=<span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">False,  
  4. random_state=<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">None, shrinking=<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">True, tol=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.001, verbose=<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">False)</span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></code>  
>>>clf.fit(digits.data[:-1], digits.target[:-1])  
SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
gamma=0.001, kernel='rbf', max_iter=-1, probability=False,
random_state=None, shrinking=True, tol=0.001, verbose=False)

現在你可以預測新的數值了。我們可以讓這個訓練器告訴我們digits資料集我們沒有作為訓練資料使用的最後一張影象是什麼數字。

[python] view plain copyprint?
  1. <code class="hljs" style="margin:0px; padding:0px">>>>clf.predict(digits.data[-<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1])  
  2. array([<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8])</span></span></code>  
>>>clf.predict(digits.data[-1])
array([8])

相應的圖片如下圖:
此處輸入圖片的描述

正如你所看到的,這是一個很有挑戰的任務:這張圖片的解析度很低。你同意分類器給出的答案嗎?

這個分類問題的完整示例在這裡識別手寫數字,你可以執行並使用它。[譯:看本文附錄]

四、模型持久化

可以使用Python的自帶模組——pickle來儲存scikit中的模型:

[python] view plain copyprint?
  1. <code class="hljs" style="margin:0px; padding:0px">>>><span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">from sklearn <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import svm  
  2. >>><span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">from sklearn <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import datasets  
  3. >>>clf = svm.SVC()  
  4. >>>iris = datasets.load_iris()  
  5. >>>X, y = iris.data, iris.target  
  6. >>>clf.fit(X, y)    
  7. SVC(C=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1.0, cache_size=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">200, class_weight=<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">None, coef0=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.0, degree=<span cla