1. 程式人生 > >Sklearn__SVM實現手寫數字識別

Sklearn__SVM實現手寫數字識別

1、 資料準備

from sklearn.model_selection import StratifiedShuffleSplit
import pandas as pd
import numpy  as np
from sklearn.datasets import fetch_mldata

class Data_need():
	"""Load a dataset and turn it into binarised, stratified train/test arrays.

	percent: fraction of the samples reserved for the test split
	data_name: dataset name handed to ``fetch_mldata``
	"""

	def __init__(self, percent, data_name):
		self.percent = percent
		self.data_name = data_name

	def get_data(self):
		"""Fetch the dataset and return its ``(data, target)`` pair."""
		data_home = r'D:\Python_data\python Data\sklearn'
		# NOTE(review): fetch_mldata was removed in scikit-learn 0.22; newer
		# versions need fetch_openml, which uses different dataset names.
		mnist = fetch_mldata(self.data_name, data_home=data_home)
		return mnist['data'], mnist['target']

	def random_data(self, x, y):
		"""Shuffle the samples and split them into stratified train/test frames.

		x: 2-D feature matrix, one row per sample
		y: labels, one per row of ``x``

		Returns ``(mnist_train, mnist_test)`` DataFrames whose last column
		``'y'`` holds the label; the split is stratified on that column.
		"""
		# Put features and label side by side so a single row selection keeps
		# them aligned.
		data_y = pd.DataFrame(y, columns=['y'])
		n = len(x[0])
		data_x = pd.DataFrame(x, columns=list(range(n)))
		mnist_data = pd.merge(data_x, data_y, right_index=True, left_index=True)
		# Stratified sampling; n_splits=1 yields exactly one (train, test)
		# pair of index arrays, so take it directly.
		split = StratifiedShuffleSplit(n_splits=1, test_size=self.percent, random_state=42)
		train_index, test_index = next(split.split(mnist_data, mnist_data['y']))
		mnist_train = mnist_data.loc[train_index, :]
		mnist_test = mnist_data.loc[test_index, :]
		return mnist_train, mnist_test

	@staticmethod
	def _binarize(frame):
		"""Map every column except the last to 1 where non-zero, else 0."""
		return (np.array(frame.iloc[:, :-1]) != 0) * 1

	def train_test_data(self, train, test):
		"""Return ``(x_train, y_train, x_test, y_test)`` as NumPy arrays.

		Pixel columns are converted to binary 0/1 values; labels are read
		from the ``'y'`` column of each frame.
		"""
		return (
			self._binarize(train),
			np.array(train['y']),
			self._binarize(test),
			np.array(test['y']),
		)


if __name__ == '__main__':
	data_need = Data_need(0.3, 'MNIST original')
	x, y = data_need.get_data()
	train, test = data_need.random_data(x, y)
	x_train_in, y_train_in, x_test_in, y_test_in = data_need.train_test_data(train, test)

2、檢視資料及模型訓練

模型採用 OvR(one-vs-rest,亦稱 OvA)的線性 SVM 模型

from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
plt.style.use('ggplot')

def to_plot(num, n):
	"""Show one training image of the requested digit.

	num: digit label to draw
	n: index of the sample among the training images labelled ``num``

	Reads the module-level ``x_train_in`` / ``y_train_in`` arrays.
	"""
	# Pick the n-th sample whose label equals num and view it as a 28x28 image.
	pixels = x_train_in[y_train_in == num][n]
	plt.imshow(pixels.reshape(28, 28), cmap=plt.cm.binary, interpolation='nearest')
	plt.axis('off')
	plt.show()


if __name__ == '__main__':
	# Preview the sample at index 10 among the training images labelled 8.
	to_plot(8, 10)
	# One-vs-rest linear SVM with hinge loss and C=5.
	ova_svm_clf = LinearSVC(
		loss='hinge',
		C=5,
		multi_class='ovr',
	)
	ova_svm_clf.fit(x_train_in, y_train_in)
	# Predictions obtained through 3-fold cross-validation on the training split.
	y_prd = cross_val_predict(ova_svm_clf, x_train_in, y_train_in, cv=3)
	# Confusion matrix of true labels against the cross-validated predictions.
	conf_m = confusion_matrix(y_train_in, y_prd)


在這裡插入圖片描述

3、模型評估


### 整體的準確率
def clf_correct(y_train, y_prd):
	"""Return the overall accuracy: the share of predictions equal to the truth.

	y_train: true labels (array-like)
	y_prd: predicted labels, same shape as ``y_train``

	Raises ValueError when the two inputs differ in shape and
	ZeroDivisionError when they are empty.
	"""
	truth = np.asarray(y_train)
	predicted = np.asarray(y_prd)
	if truth.shape != predicted.shape:
		raise ValueError(
			"y_train and y_prd must have the same shape, got {} and {}".format(
				truth.shape, predicted.shape))
	# Compare for equality instead of subtracting, so non-numeric labels
	# (e.g. strings) and plain lists work too; counting is done inside NumPy.
	return np.count_nonzero(truth == predicted) / truth.size


class plot_conf_m():
	"""Draw a confusion matrix as grayscale images with matplotlib.

	conf_m: the confusion matrix to draw; stored unchanged on the instance
	"""

	def __init__(self, conf_m):
		self.conf_m = conf_m

	def plt_conf_m(self):
		"""Render the stored confusion matrix with ``plt.matshow``."""
		plt.matshow(self.conf_m, cmap=plt.cm.gray)

	def plt_error_conf_m(self):
		"""Render only the misclassifications, scaled by each row's total."""
		# Divide each row by its sum so the cells become per-row rates; the
		# division builds a new array, so self.conf_m is left untouched.
		per_row_total = self.conf_m.sum(axis=1, keepdims=True)
		error_rates = self.conf_m / per_row_total
		# Zero the diagonal (correct predictions) so only the errors stand out.
		np.fill_diagonal(error_rates, 0)
		plt.matshow(error_rates, cmap=plt.cm.gray)


if __name__ == '__main__':
	# Overall accuracy of the cross-validated predictions.
	overall_accuracy = clf_correct(y_train_in, y_prd)
	print("整體準確性:{}".format(overall_accuracy))
	plt_confm = plot_conf_m(conf_m)
	# First figure: the raw confusion matrix.
	plt_confm.plt_conf_m()
	plt.title("Focus on the correct prediction")
	# Second figure: the same matrix with the diagonal zeroed out.
	plt_confm.plt_error_conf_m()
	plt.title("Focus on the error prediction")
	plt.show()

##  整體準確性:0.902795918367347

從下面兩個混淆矩陣中可以看出,錯誤分類的分佈比較平均,準確率還有待提高,所以增大 C 重新進行擬合。
在這裡插入圖片描述

4、模型修正及預測

1. 模型修正

if __name__ == '__main__':
	# Same one-vs-rest linear SVM as before, refitted with C raised to 10.
	ova_svm_clf_fix = LinearSVC(
		loss='hinge',
		C=10,
		multi_class='ovr',
	)
	ova_svm_clf_fix.fit(x_train_in, y_train_in)
	# Predictions obtained through 3-fold cross-validation on the training split.
	y_prd_fix = cross_val_predict(ova_svm_clf_fix, x_train_in, y_train_in, cv=3)
	# Confusion matrix of true labels against the cross-validated predictions.
	conf_m_fix = confusion_matrix(y_train_in, y_prd_fix)

	overall_accuracy_fix = clf_correct(y_train_in, y_prd_fix)
	print("整體準確性:{}".format(overall_accuracy_fix))
	plt_confm_fix = plot_conf_m(conf_m_fix)
	plt_confm_fix.plt_conf_m()
	plt.title("Focus on the correct prediction")
	plt_confm_fix.plt_error_conf_m()
	plt.title("Focus on the error prediction")
	plt.show()

## 整體準確率0.91204

增大 C 雖然略微提高了整體的準確率(0.903 → 0.912),但錯誤分類的情況並沒有明顯好轉,可見線性核對該資料的分類效果有限,所以改用高斯核(RBF)進行擬合。
在這裡插入圖片描述

from sklearn.svm import SVC
from sklearn.metrics import classification_report

if __name__ == '__main__':
	# One-vs-rest decision function on top of an RBF-kernel SVC.
	ova_svm_clf_rbf = SVC(
		kernel='rbf',
		gamma='auto',
		C=15,
		cache_size=8000,
		decision_function_shape='ovr',
	)
	ova_svm_clf_rbf.fit(x_train_in, y_train_in)
	# These predictions are made on the same data the model was fitted on,
	# so the accuracy printed here is a training accuracy, not a
	# cross-validated one.
	y_prd_rbf = ova_svm_clf_rbf.predict(x_train_in)
	overall_accuracy_rbf = clf_correct(y_train_in, y_prd_rbf)
	print('整體準確率{}'.format(overall_accuracy_rbf))
	conf_m_rbf = confusion_matrix(y_train_in, y_prd_rbf)
	plt_confm_rbf = plot_conf_m(conf_m_rbf)
	plt_confm_rbf.plt_conf_m()
	plt.title("Focus on the correct prediction")
	plt_confm_rbf.plt_error_conf_m()
	plt.title("Focus on the error prediction")
	plt.show()
	# Print the detailed per-class report.
	print(classification_report(y_train_in, y_prd_rbf))

"""
# 整體準確率:0.9831632653061224
             precision    recall  f1-score   support
        0.0       0.99      0.99      0.99      4832
        1.0       0.99      0.99      0.99      5514
        2.0       0.98      0.99      0.99      4893
        3.0       0.98      0.97      0.97      4999
        4.0       0.98      0.98      0.98      4777
        5.0       0.98      0.98      0.98      4419
        6.0       0.99      0.99      0.99      4813
        7.0       0.98      0.98      0.98      5105
        8.0       0.98      0.98      0.98      4777
        9.0       0.98      0.97      0.97      4871
avg / total       0.98      0.98      0.98     49000

"""

高斯核的準確率明顯提升了,但對 9 和 4、3 和 5 的識別還不是十分精確
在這裡插入圖片描述

2. 模型預測

if __name__ == '__main__':
	# Score the held-out test split. The original referenced the undefined
	# names x_test / y_train; the test arrays are x_test_in / y_test_in.
	# NOTE(review): the recorded test accuracy (0.96) is well above the 0.912
	# cross-validated score of this linear model, so the run may have used
	# ova_svm_clf_rbf instead - confirm which classifier was intended.
	y_test_prd = ova_svm_clf_fix.predict(x_test_in)
	print("整體準確性:{}".format(clf_correct(y_test_in, y_test_prd)))
	# Build the confusion matrix from the test predictions rather than
	# reusing the training-set matrix conf_m.
	conf_m_test = confusion_matrix(y_test_in, y_test_prd)
	plt_confm_test = plot_conf_m(conf_m_test)
	plt_confm_test.plt_conf_m()
	plt.title("Focus on the correct prediction")
	plt_confm_test.plt_error_conf_m()
	plt.title("Focus on the error prediction")
	plt.show()
	# Print the detailed per-class report.
	print(classification_report(y_test_in, y_test_prd))

"""
# 整體準確性:0.9615238095238096
            precision    recall  f1-score   support
        0.0       0.97      0.99      0.98      2071
        1.0       0.97      0.98      0.98      2363
        2.0       0.96      0.97      0.96      2097
        3.0       0.95      0.95      0.95      2142
        4.0       0.96      0.96      0.96      2047
        5.0       0.96      0.94      0.95      1894
        6.0       0.97      0.98      0.97      2063
        7.0       0.97      0.96      0.97      2188
        8.0       0.95      0.95      0.95      2048
        9.0       0.94      0.94      0.94      2087
avg / total       0.96      0.96      0.96     21000

"""

在這裡插入圖片描述