
Recognizing the MNIST dataset with a plain kNN algorithm (with code)

I originally wanted to build a kd-tree to play with kNN, but a dataset like MNIST is genuinely hard to split by dimension. Printing the data out, it looks like any gray value greater than 3 is part of a handwritten stroke, so taking a median per dimension really doesn't work. Given that, let's first run plain kNN to see the accuracy and running time, and figure out how to handle MNIST with a kd-tree later. The basics of kNN are covered in my post "k-means, GMM clustering, and KNN: an overview of the principles"; a fairly complete introduction is at http://www.hankcs.com/ml/k-nearest-neighbor-method.html
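The thresholding observation above can be checked with a tiny sketch. The patch values here are invented for illustration; the clipping itself is the same `np.minimum(data, 1)` call used in the extraction code below:

```python
import numpy as np

# A made-up 3x3 patch of MNIST-style gray values (0-255);
# anything above zero counts as "ink".
patch = np.array([[0,   3, 200],
                  [0, 128,   0],
                  [7,   0,   0]], dtype=np.uint8)

# Binarize the way the extraction code does: clip every value to 0 or 1.
binary = np.minimum(patch, 1)
print(binary)
```

With values this skewed toward zero, a per-dimension median is almost always 0, which is why splitting on the median is useless here.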
Below is the code for recognizing MNIST with kNN. It covers downloading and extracting MNIST, running the kNN test, and timing how long classifying 1000 test images takes.

#coding:utf-8
import numpy as np
import os
import gzip
from six.moves import urllib
from datetime import datetime

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

# Download the MNIST dataset, modeled on tensorflow's base.py.
def maybe_download(filename, path, source_url):
    if not os.path.exists(path):
        os.makedirs(path)
    filepath = os.path.join(path, filename)
    if not os.path.exists(filepath):
        urllib.request.urlretrieve(source_url, filepath)
    return filepath

# Read 32 bits at a time, for the magic number, image count and dimensions.
# Modeled on tensorflow's mnist.py.
def _read32(bytestream):
    dt = np.dtype(np.uint32).newbyteorder('>')
    return np.frombuffer(bytestream.read(4), dtype=dt)[0]

# Extract the images; optionally binarize the gray values, and return either
# a matrix (one row per image) or a tensor, as needed.
# Modeled on tensorflow's mnist.py.
def extract_images(input_file, is_value_binary, is_matrix):
    with gzip.open(input_file, 'rb') as zipf:
        magic = _read32(zipf)
        if magic != 2051:
            raise ValueError('Invalid magic number %d in MNIST image file: %s'
                             % (magic, input_file))
        num_images = _read32(zipf)
        rows = _read32(zipf)
        cols = _read32(zipf)
        print(magic, num_images, rows, cols)
        buf = zipf.read(rows * cols * num_images)
        data = np.frombuffer(buf, dtype=np.uint8)
        if is_matrix:
            data = data.reshape(num_images, rows * cols)
        else:
            data = data.reshape(num_images, rows, cols)
        if is_value_binary:
            return np.minimum(data, 1)
        else:
            return data

# Extract the labels. Modeled on tensorflow's mnist.py.
def extract_labels(input_file):
    with gzip.open(input_file, 'rb') as zipf:
        magic = _read32(zipf)
        if magic != 2049:
            raise ValueError('Invalid magic number %d in MNIST label file: %s'
                             % (magic, input_file))
        num_items = _read32(zipf)
        buf = zipf.read(num_items)
        labels = np.frombuffer(buf, dtype=np.uint8)
        return labels

# Plain kNN classification: compute the distance to all training samples at
# once, find the k images with the smallest distances, and take the majority
# label among those k as newInput's label.
# Copied from http://blog.csdn.net/zouxy09/article/details/16955347
def kNNClassify(newInput, dataSet, labels, k):
    numSamples = dataSet.shape[0]  # shape[0] is the number of rows
    newInput = newInput.reshape(1, newInput.shape[0])
    # np.tile(A, B): repeat A B times, equivalent to [A]*B
    diff = np.tile(newInput, (numSamples, 1)) - dataSet  # subtract element-wise
    squaredDiff = diff ** 2                   # square the differences
    squaredDist = np.sum(squaredDiff, axis=1) # sum per row
    distance = squaredDist ** 0.5
    sortedDistIndices = np.argsort(distance)

    classCount = {}  # dictionary of label -> vote count
    for i in range(k):
        ## step 3: pick the i-th smallest distance
        voteLabel = labels[sortedDistIndices[i]]
        ## step 4: count how often each label occurs; get() returns 0 when
        ## voteLabel is not yet a key in classCount
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
    ## step 5: return the label with the most votes
    maxCount = 0
    maxIndex = 0
    for key, value in classCount.items():
        if value > maxCount:
            maxCount = value
            maxIndex = key
    return maxIndex

maybe_download('train_images', 'data/mnist', SOURCE_URL + TRAIN_IMAGES)
maybe_download('train_labels', 'data/mnist', SOURCE_URL + TRAIN_LABELS)
maybe_download('test_images', 'data/mnist', SOURCE_URL + TEST_IMAGES)
maybe_download('test_labels', 'data/mnist', SOURCE_URL + TEST_LABELS)

# Main routine: load the images, then test on the handwritten digits.
# Copied from http://blog.csdn.net/zouxy09/article/details/16955347
def testHandWritingClass():
    ## step 1: load data
    print("step 1: load data...")
    train_x = extract_images('data/mnist/train_images', True, True)
    train_y = extract_labels('data/mnist/train_labels')
    test_x = extract_images('data/mnist/test_images', True, True)
    test_y = extract_labels('data/mnist/test_labels')

    ## step 2: training (kNN has no training phase)
    print("step 2: training...")

    ## step 3: testing
    print("step 3: testing...")
    a = datetime.now()
    numTestSamples = test_x.shape[0]
    matchCount = 0
    test_num = numTestSamples // 10
    for i in range(test_num):
        predict = kNNClassify(test_x[i], train_x, train_y, 3)
        if predict == test_y[i]:
            matchCount += 1
        if i % 100 == 0:
            print("finished %d images" % i)
    accuracy = float(matchCount) / test_num
    b = datetime.now()
    print("took %d seconds in total" % (b - a).seconds)

    ## step 4: show the result
    print("step 4: show the result...")
    print('The classify accuracy is: %.2f%%' % (accuracy * 100))

if __name__ == '__main__':
    testHandWritingClass()
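As a quick sanity check of the distance-and-vote logic, here is the same procedure condensed onto a toy 2-D dataset (the points and labels are invented for illustration, and `Counter` replaces the manual vote-counting loop):

```python
import numpy as np
from collections import Counter

def knn_classify(new_input, data_set, labels, k):
    # Euclidean distance from the query to every training sample.
    dist = np.sqrt(((data_set - new_input) ** 2).sum(axis=1))
    # Labels of the k nearest samples; the most common one wins.
    nearest = np.argsort(dist)[:k]
    return Counter(labels[nearest]).most_common(1)[0][0]

train = np.array([[0, 0], [0, 1], [5, 5], [6, 5]])
labels = np.array([0, 0, 1, 1])
# The query sits next to the two class-0 points, so with k=3 the
# vote is 2-to-1 in favor of class 0.
print(knn_classify(np.array([0.5, 0.5]), train, labels, 3))  # → 0
```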

The output of a run looks like this:

step 1: load data...
2051 60000 28 28
2051 10000 28 28
step 2: training...
step 3: testing...
finished 0 images
finished 100 images
finished 200 images
finished 300 images
finished 400 images
finished 500 images
finished 600 images
finished 700 images
finished 800 images
finished 900 images
took 234 seconds in total
step 4: show the result...
The classify accuracy is: 96.20%

Classifying 1000 images took 234 seconds, a larger time cost than a simple CNN. The 96.2% accuracy beats only softmax regression, which gets about 92%; a multilayer perceptron can reach about 98%, trains quickly, and tests even faster.
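Much of that 234 seconds comes from looping over test images one at a time and fully sorting 60000 distances per query. A possible speedup, not part of the original post, is to compute distances for a whole batch of queries with one matrix product and use np.argpartition instead of a full sort; a minimal sketch on toy data:

```python
import numpy as np

def knn_batch_predict(test_x, train_x, train_y, k=3):
    """Predict labels for all test rows at once via vectorized distances."""
    # Squared Euclidean distances, shape (n_test, n_train), via the expansion
    # |a-b|^2 = |a|^2 - 2*a.b + |b|^2; no sqrt needed when only ranking.
    d2 = ((test_x ** 2).sum(1)[:, None]
          - 2 * test_x @ train_x.T
          + (train_x ** 2).sum(1)[None, :])
    # Indices of the k smallest distances per row (partial sort suffices).
    idx = np.argpartition(d2, k, axis=1)[:, :k]
    votes = train_y[idx]
    # Majority vote per row.
    return np.array([np.bincount(v).argmax() for v in votes])

# Toy check with two well-separated clusters (data invented for illustration).
train = np.array([[0, 0], [1, 0], [10, 10], [11, 10]], dtype=float)
y = np.array([0, 0, 1, 1])
test = np.array([[0.2, 0.1], [10.5, 9.9]], dtype=float)
print(knn_batch_predict(test, train, y, k=2))  # → [0 1]
```

On MNIST-sized data the distance matrix for all 1000 queries at once is 1000 x 60000 floats (about 480 MB as float64), so in practice you would process the test set in chunks of a few hundred rows.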