1. 程式人生 > >SVM 解決類別不平衡問題(scikit_learn)

SVM 解決類別不平衡問題(scikit_learn)

在支援向量機中,C 是負責懲罰錯誤分類資料的超引數。

解決資料類別不平衡的一個方法就是使用基於類別增加權重的C

Cj=Cwj

其中,C是誤分類的懲罰項,wj是與類別 j 的出現頻率成反比的權重引數,Cj 就是類別 j 對應的 加權C

主要思路就是增大誤分類 少數類別 帶來的影響,保證 少數類別 的分類正確性,避免被多數類別掩蓋

在scikit-learn 中,使用 svc 方法時,可以通過設定引數

class_weight=’balanced’

實現上述加權功能

引數‘balanced’ 會自動按照以下公式計算權值:

wj=nknj

其中,wj 為類別 j 對應權值,n 為資料總數,k為類別數量,即資料有k 個種類,nj是類別 j 的資料個數

0.匯入庫

# Load libraries
from sklearn.svm import SVC
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
import numpy as np

1、載入Iris Flower 資料集

#只加載兩個類別的資料,兩類,各50個
iris = datasets.load_iris()
X = iris.data[:100
,:] y = iris.target[:100]

2.不均衡化資料集

# 刪掉前四十個資料,資料總數變為60個
X = X[40:,:]
y = y[40:]

# 類別為0的類別不變,類別不為0的全部變為1
y = np.where((y == 0), 0, 1)
y
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

可以看到,有60個數據,10個為類別0,50個為類別1

3.特徵標準化

# Standarize features
scaler = StandardScaler()
X_std = scaler.fit_transform(X)

4.使用加權類別訓練SVM分類器

# Create support vector classifier
svc = SVC(kernel='linear', class_weight='balanced', C=1.0, random_state=0)

# Train classifier
model = svc.fit(X_std, y)

翻譯自Chris Albon 部落格
原文地址