1. 程式人生 > >Python實現支援向量機(SVM) MNIST資料集

Python實現支援向量機(SVM) MNIST資料集

Python實現支援向量機(SVM) MNIST資料集

SVM的原理這裡不講,大家自己可以查閱相關資料。

下面是利用sklearn庫進行svm訓練MNIST資料集,準確率可以達到90%以上。


from sklearn import svm
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
train_num = 10000
test_num = 1000
x_train = mnist.train.images y_train = mnist.train.labels x_test = mnist.test.images y_test = mnist.test.labels # 獲取一個支援向量機模型 predictor = svm.SVC(gamma='scale', C=1.0, decision_function_shape='ovr', kernel='rbf') # 把資料丟進去 predictor.fit(x_train[:train_num], y_train[:train_num]) # 預測結果 result = predictor.
predict(x_test[:test_num]) # 準確率估計 accurancy = np.sum(np.equal(result, y_test[:test_num])) / test_num print(accurancy)

SVC函式的引數解析

gamma

支援向量機的間隔,即是超平面距離不同類別的最小距離,是一個float型別的值,可以自己規定,也可以用SVM自己的值,有兩個選擇。

  • auto 選擇auto時,gamma = 1/feature_num ,也就是特徵的數目分之1
  • scale 選擇scale時,gamma = 1/(feature_num * X.std()), 特徵數目乘樣本標準差分之1. 一般來說,scael比auto結果準確。
C

看到過SVM公式推導的同學對C一定不陌生,它是鬆弛變數的係數,稱為懲罰係數,用來調整容忍鬆弛度,當C越大,說明該模型對分類錯誤更加容忍,也就是為了避免過擬合。

decision_function_shape

兩個選擇

  • ovr one vs rest 將一個類別與其他所有類別進行劃分
  • ovo one vs one 兩兩劃分

kernel

核函式的選擇

  • 當樣本線性可分時,一般選擇linear 線性核函式
  • 當樣本線性不可分時,有很多選擇,這裡選擇rbf 即徑向基函式,又稱高斯核函式。