1. 程式人生 > 用SVM(有核和無核函式)進行MNIST手寫數字的分類

用SVM(有核和無核函式)進行MNIST手寫數字的分類

1. 以普通（線性、無核函式，LinearSVC）SVM 分類 MNIST 資料集

 # Imports used by the MNIST SVM demo.
import os
import struct
import time

import matplotlib.pyplot as plt
import numpy as np
# SVM models (LinearSVC / SVC).
from sklearn import svm
# Feature standardization (StandardScaler).
from sklearn import preprocessing

# Directory holding the raw, uncompressed MNIST IDX files
# (e.g. train-labels-idx1-ubyte, t10k-images-idx3-ubyte).
path = './dataset/mnist/raw'
14 def load_mnist_train(path, kind='
train'): 15 labels_path = os.path.join(path,'%s-labels-idx1-ubyte'% kind) 16 images_path = os.path.join(path,'%s-images-idx3-ubyte'% kind) 17 with open(labels_path, 'rb') as lbpath: 18 magic, n = struct.unpack('>II',lbpath.read(8)) 19 labels = np.fromfile(lbpath,dtype=np.uint8)
20 with open(images_path, 'rb') as imgpath: 21 magic, num, rows, cols = struct.unpack('>IIII',imgpath.read(16)) 22 images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels), 784) 23 return images, labels 24 def load_mnist_test(path, kind='t10k'): 25 labels_path = os.path.join(path,'
%s-labels-idx1-ubyte'% kind) 26 images_path = os.path.join(path,'%s-images-idx3-ubyte'% kind) 27 with open(labels_path, 'rb') as lbpath: 28 magic, n = struct.unpack('>II',lbpath.read(8)) 29 labels = np.fromfile(lbpath,dtype=np.uint8) 30 with open(images_path, 'rb') as imgpath: 31 magic, num, rows, cols = struct.unpack('>IIII',imgpath.read(16)) 32 images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels), 784) 33 return images, labels 34 train_images,train_labels=load_mnist_train(path) 35 test_images,test_labels=load_mnist_test(path) 36 37 X=preprocessing.StandardScaler().fit_transform(train_images) 38 X_train=X[0:60000] 39 y_train=train_labels[0:60000] 40 41 print(time.strftime('%Y-%m-%d %H:%M:%S')) 42 model_svc = svm.LinearSVC() 43 #model_svc = svm.SVC() 44 model_svc.fit(X_train,y_train) 45 print(time.strftime('%Y-%m-%d %H:%M:%S')) 46 47 ##顯示前30個樣本的真實標籤和預測值,用圖顯示 48 x=preprocessing.StandardScaler().fit_transform(test_images) 49 x_test=x[0:10000] 50 y_pred=test_labels[0:10000] 51 print(model_svc.score(x_test,y_pred)) 52 y=model_svc.predict(x) 53 54 fig1=plt.figure(figsize=(8,8)) 55 fig1.subplots_adjust(left=0,right=1,bottom=0,top=1,hspace=0.05,wspace=0.05) 56 for i in range(100): 57 ax=fig1.add_subplot(10,10,i+1,xticks=[],yticks=[]) 58 ax.imshow(np.reshape(test_images[i], [28,28]),cmap=plt.cm.binary,interpolation='nearest') 59 ax.text(0,2,"pred:"+str(y[i]),color='red') 60 #ax.text(0,32,"real:"+str(test_labels[i]),color='blue') 61 plt.show()

2.執行結果:

開始時間:2018-11-17 08:31:09

結束時間:2018-11-17 08:53:04

用時:21分55秒

準確率（accuracy，10000 張測試圖片）:0.9122

預測圖片（前 100 張測試圖片，左上角紅字為模型的預測標籤）: