Python實現支援向量機(SVM) MNIST資料集
阿新 • • 發佈:2018-11-20
Python實現支援向量機(SVM) MNIST資料集
SVM的原理這裡不講,大家自己可以查閱相關資料。
下面是利用sklearn庫進行svm訓練MNIST資料集,準確率可以達到90%以上。
from sklearn import svm
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
train_num = 10000
test_num = 1000
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels
# 獲取一個支援向量機模型
predictor = svm.SVC(gamma='scale', C=1.0, decision_function_shape='ovr', kernel='rbf')
# 把資料丟進去
predictor.fit(x_train[:train_num], y_train[:train_num])
# 預測結果
result = predictor. predict(x_test[:test_num])
# 準確率估計
accurancy = np.sum(np.equal(result, y_test[:test_num])) / test_num
print(accurancy)
SVC函式的引數解析
gamma
支援向量機的間隔,即是超平面距離不同類別的最小距離,是一個float型別的值,可以自己規定,也可以用SVM自己的值,有兩個選擇。
auto
選擇auto時,gamma = 1/feature_num ,也就是特徵的數目分之1scale
選擇scale時,gamma = 1/(feature_num * X.std()), 特徵數目乘樣本標準差分之1. 一般來說,scael比auto結果準確。
C
看到過SVM公式推導的同學對C一定不陌生,它是鬆弛變數的係數
,稱為懲罰係數,用來調整容忍鬆弛度
,當C越大,說明該模型對分類錯誤更加容忍,也就是為了避免過擬合。
decision_function_shape
兩個選擇
ovr
one vs rest 將一個類別與其他所有類別進行劃分ovo
one vs one 兩兩劃分
kernel
核函式的選擇
- 當樣本線性可分時,一般選擇
linear
線性核函式 - 當樣本線性不可分時,有很多選擇,這裡選擇
rbf
即徑向基函式,又稱高斯核函式。