1. 程式人生 > >支援向量機SVM理解篇

支援向量機SVM理解篇

 

Posted
by 子顥 on August 14,2018
#演算法原理
支援向量機(Support Vector Machine,SVM)是機器學習中的最經典也是最重要的分類方法之一。



樣本空間中任一點x到超平面的距離為:


現在我們希望求解上式來得到最大間隔超平面所對應的模型:f(x) = w * x + b




下面還是通過一個具體例子感受一下線性可分支援向量機的訓練過程。


核函式(kernel trick),我們線上性迴歸等幾個小節中曾經提到過核函式的概念,polynomial也是核函式的一種。






那麼我們在實際應用當中到底應該怎樣選擇核函式呢?告訴大家一條鐵律:首先選擇線性核(LinearSVC),如果訓練集不太大,再試一下RBF核。 只有一個對稱函式所對應的核矩陣半正定,它才能作為核函式使用(亦即才能拆成對映函式的內積)。



為了解決這個問題,可以對每個樣本點引進一個對應的鬆弛變數,用以表徵樣本不滿足約束的程度,使函式間隔加上鬆弛變數大於等於1。這樣,約束條件變為:


我們從損失函式的角度看,gamma表示樣本不滿足約束的程度,如果樣本滿足約束,那麼gamma值為0。所以這實際上是hinge損失:hinge(z) = max(0, 1-z)。上面7.31式可以改寫為:

加號後的一項就是SVM的hinge損失函式,加號前的一項恰好是L2正則。如果我們將上式中的損失函式變為對數損失,那麼恰好變成了加了L2正則的邏輯迴歸。

既然講到了這裡,那我們不防繼續深入一下,試著從損失函式的角度探討SVM和LR各自的特點是什麼?
1.    因為LR和SVM的優化目標接近(損失函式漸進趨同),所以通常情況下他們的表現也相當。
2.    SVM的hinge損失函式在z大於1後,都是平坦的0區域,這使得SVM的解具有稀疏性(只與支援向量有關,函式影象拐點位置);而LR的log損失是光滑的單調遞減函式,不能匯出類似支援向量的概念。因此LR的求解過程依賴於所有樣本點,開銷更大(尤其是需要用到核函式時)。
3.    SVM和LR都是使用一個最優分隔超平面進行樣本點的劃分,且距離分隔超平面越遠的點對模型的訓練影響越小。SVM是完全無影響(平坦的0區域),LR是影響較弱(損失函式漸進趨於0)。
4.    因為SVM的訓練只與支援向量點有關,所以資料unbalance對SVM幾乎無影響,而LR一般需要做樣本均衡處理。
5.    LR迴歸的輸出具有自然的概率含義,SVM的輸出是樣本點到最優超平面的距離,欲得到概率需要進行特殊處理。
我們依然通過拉格朗日乘子法求解加入鬆弛變數的SVM:



支援向量迴歸(SVR):找到兩條平行直線帶,帶內點的損失為0,帶上的點是儘可能多的支援向量。


訓練方法與SVM相同。
#模型訓練
程式碼地址 https://github.com/qianshuang/ml-exp
def train():
   print("start training...")
   # 處理訓練資料
   # train_feature, train_target = process_file(train_dir, word_to_id, cat_to_id)  # 詞頻特徵
   train_feature, train_target = process_tfidf_file(train_dir, word_to_id, cat_to_id)  # TF-IDF特徵
   # 模型訓練
   model.fit(train_feature, train_target)
def test():
   print("start testing...")
   # 處理測試資料
   test_feature, test_target = process_file(test_dir, word_to_id, cat_to_id)
   # test_predict = model.predict(test_feature)  # 返回預測類別
   test_predict_proba = model.predict_proba(test_feature)    # 返回屬於各個類別的概率
   test_predict = np.argmax(test_predict_proba, 1)  # 返回概率最大的類別標籤
   # accuracy
   true_false = (test_predict == test_target)
   accuracy = np.count_nonzero(true_false) / float(len(test_target))
   print()
   print("accuracy is %f" % accuracy)
   # precision    recall  f1-score
   print()
   print(metrics.classification_report(test_target, test_predict, target_names=categories))
   # 混淆矩陣
   print("Confusion Matrix...")
   print(metrics.confusion_matrix(test_target, test_predict))
if not os.path.exists(vocab_dir):
   # 構建詞典表
   build_vocab(train_dir, vocab_dir)
categories, cat_to_id = read_category()
words, word_to_id = read_vocab(vocab_dir)
# kNN
# model = neighbors.KNeighborsClassifier()
# decision tree
# model = tree.DecisionTreeClassifier()
# random forest
# model = ensemble.RandomForestClassifier(n_estimators=10)  # n_estimators為基決策樹的數量,一般越大效果越好直至趨於收斂
# AdaBoost
# model = ensemble.AdaBoostClassifier(learning_rate=1.0)  # learning_rate的作用是收縮基學習器的權重貢獻值
# GBDT
# model = ensemble.GradientBoostingClassifier(n_estimators=10)
# xgboost
# model = xgboost.XGBClassifier(n_estimators=10)
# Naive Bayes
model = naive_bayes.MultinomialNB()
# logistic regression
# model = linear_model.LogisticRegression()   # ovr
# model = linear_model.LogisticRegression(multi_class="multinomial", solver="lbfgs")  # softmax迴歸
# SVM
model = svm.LinearSVC()  # 線性,無概率結果
model = svm.SVC(probability=True)  # 核函式,訓練慢
train()
test()
執行結果:
read_category...
read_vocab...
start training...
start testing...
accuracy is 0.970000
            precision    recall  f1-score   support
        遊戲       1.00      1.00      1.00       104
        時政       0.92      0.93      0.92        94
        體育       1.00      0.99      1.00       116
        娛樂       0.99      0.99      0.99        89
        時尚       1.00      0.99      0.99        91
        教育       0.97      0.94      0.96       104
        家居       0.91      0.96      0.93        89
        財經       0.96      0.96      0.96       115
        科技       1.00      0.99      0.99        94
        房產       0.94      0.96      0.95       104
avg / total       0.97      0.97      0.97      1000
Confusion Matrix...
[[104   0   0   0   0   0   0   0   0   0]
[  0  87   0   0   0   0   1   3   0   3]
[  0   1 115   0   0   0   0   0   0   0]
[  0   1   0  88   0   0   0   0   0   0]
[  0   0   0   0  90   1   0   0   0   0]
[  0   1   0   1   0  98   3   0   0   1]
[  0   1   0   0   0   2  85   1   0   0]
[  0   1   0   0   0   0   2 110   0   2]
[  0   0   0   0   0   0   1   0  93   0]
[  0   3   0   0   0   0   1   0   0 100]]

社群
QQ交流群

微信公眾號
瞭解更多幹貨文章,可以關注小程式八斗問答