利用Python實現k最近鄰演算法 並識別手寫數字(詳細註釋)
阿新 • • 發佈:2019-02-06
K最近鄰(k-Nearest Neighbor,KNN)分類演算法,是一個理論上比較成熟的方法,也是較為簡單的機器學習演算法之一。該方法的思路是:如果一個樣本在特徵空間中的k個最相似(即特徵空間中最鄰近)的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別。
K最近鄰演算法(k-nearest neighbors)是一種有監督分類的機器學習演算法。顧名思義,其演算法主體思想就是根據距離相近的鄰居類別, 來判定自己的所屬類別。演算法的前提是需要有一個已被標記類別的訓練資料集,具體的計算步驟分為一下三步: 1、計算測試物件與訓練集中所有物件的距離,可以是歐式距離、餘弦距離等,比較常用的是較為簡單的歐式距離; 2、找出上步計算的距離中最近的K個物件,作為測試物件的鄰居; 3、找出K個物件中出現頻率最高的物件,其所屬的類別就是該測試物件所屬的類別。
實現K最近鄰演算法
#! /usr/bin/env python # -*- coding:utf-8 -*- """ k-NearestNeighbor k近臨演算法的python實現 """ import numpy as np def classify(X, DATASET, LABELS, K): # 大寫引數表示常量,不可改變 distances = np.sqrt(np.sum(np.square(DATASET - X), axis=1)) # 計算距離矩陣 len_dis = len(distances) # 得到distances數目 labels = [] # 儲存標籤 for i in range(0, K): min_value = distances[i] min_value_idx = i for j in range(i + 1, len_dis): if distances[j] < min_value: min_value = distances[j] min_value_idx = j distances[i], distances[min_value_idx] = distances[min_value_idx], distances[i] labels.append(LABELS[min_value_idx]) # 選擇排序挑選出前k個最值 # 用labels儲存前k個最小距離的標籤 C = labels[0] max_count = 0 for label in labels: count = labels.count(label) if count > max_count: max_count = count C = label # 求前k個label中,重複次數最多的label,並返回 return C
影象處理相關函式
#! /usr/bin/env python # -*- coding:utf-8 -*- """ 和影象操作有關函式 """ import matplotlib.pyplot as plt import numpy as np def img2vector(filename): # 將影象轉向量 vector = np.zeros([1024], int) # 定義返回的向量,大小為1*1024 lines = None with open(filename, 'r') as f: lines = f.readlines() # 讀取32*32數字檔案 for i in range(32): for j in range(32): vector[i * 32 + j] = lines[i][j] # 將資訊存放在vector中 return vector def img2mat(filename): # 將影象轉矩陣 mat = np.zeros([32, 32], int) # 定義返回的矩陣,大小為32*32 lines = None with open(filename, 'r') as f: lines = f.readlines() # 讀取32*32數字檔案 for i in range(32): for j in range(32): mat[i, j] = lines[i][j] # 將資訊存放在mat中 return mat def show_img(mat): # 顯示影象 plt.imshow(mat) # plt.axis('off') plt.show()
識別手寫數字
#! /usr/bin/env python
# -*- coding:utf-8 -*-
"""
基於knn演算法的手寫數字識別
"""
from os import listdir
import matplotlib.pyplot as plt
import numpy as np
from img import img2mat, img2vector, show_img
from knn import classify
train_digits_path = '/home/user/digits/trainingDigits/'
test_digits_path = '/home/user/digits/testDigits/'
def read_dataSet(path):
file_list = listdir(path)
# 獲取資料夾下的所有檔案路徑
num_files = len(file_list)
# 統計檔案數目
dataset = np.zeros([num_files, 1024], int)
# 用於存放所有的數字檔案
labels = np.zeros([num_files])
# 用於存放對應的標籤
for i in range(num_files):
# 遍歷所有的檔案
file_path = file_list[i]
# 獲取檔名稱
digit = int(file_path.split('_')[0])
# 通過檔名獲取標籤
labels[i] = digit
# 存放標籤
dataset[i] = img2vector(path + '/' + file_path)
# 存放資料
return dataset, labels
# 讀取訓練集
train_dataset, train_labels = read_dataSet(train_digits_path)
# 讀取測試集
test_dataset, test_labels = read_dataSet(test_digits_path)
def classify_test_dataset(k):
# 對測試集進行識別
test_num = len(test_dataset)
# 測試集的數目
error_num = 0
# 錯誤數目
for data, label in zip(test_dataset, test_labels):
res = classify(data, train_dataset, train_labels, k)
# 對測試集進行預測
if res != label:
error_num += 1
# 若預測錯誤,則計數器加一
print("total:{},error num:{},error rate:{}".format(test_num, error_num, error_num / test_num))
def classify_test_data(filename):
# 對測試集合,單個檔案進行識別
file_path = test_digits_path + filename
try:
data = img2vector(file_path)
res = classify(data, train_dataset, train_labels, 3)
return int(res)
except FileNotFoundError:
print("No such file.")
if __name__ == '__main__':
# 測試
'''
mats = []
vs = []
for idx in range(4):
filename = str(idx) + '_2.txt'
mat = img2mat(test_digits_path + filename)
mats.append(mat)
pv = classify_test_data(filename)
rv = int(filename.split('_')[0])
vs.append([pv,rv])
for i in range(len(mats)):
pv,rv = vs[i]
plt.subplot(2,2,i+1)
plt.xlabel("pv:"+str(pv)+",rv:"+str(rv))
plt.imshow(mats[i])
plt.show()
'''
classify_test_dataset(3)
執行結果示例:
參考文獻
[1] Peter Harrington. Machine Learning in Action
[2] 維基百科 .K-nearest_neighbors_algorithm
[3] 百度百科 .k近鄰演算法
[4] New York University. data/digits