1. 程式人生 > >利用Python實現k最近鄰演算法 並識別手寫數字(詳細註釋)

利用Python實現k最近鄰演算法 並識別手寫數字(詳細註釋)

    K最近鄰(k-Nearest Neighbor,KNN)分類演算法,是一個理論上比較成熟的方法,也是較為簡單的機器學習演算法之一。該方法的思路是:如果一個樣本在特徵空間中的k個最相似(即特徵空間中最鄰近)的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別。

K最近鄰演算法(k-nearest neighbors)是一種有監督分類的機器學習演算法。顧名思義,其演算法主體思想就是根據距離相近的鄰居類別,
來判定自己的所屬類別。演算法的前提是需要有一個已被標記類別的訓練資料集,具體的計算步驟分為一下三步:
1、計算測試物件與訓練集中所有物件的距離,可以是歐式距離、餘弦距離等,比較常用的是較為簡單的歐式距離;
2、找出上步計算的距離中最近的K個物件,作為測試物件的鄰居;
3、找出K個物件中出現頻率最高的物件,其所屬的類別就是該測試物件所屬的類別。


實現K最近鄰演算法

#! /usr/bin/env python
# -*- coding:utf-8 -*-

"""
k-NearestNeighbor
k近臨演算法的python實現
"""

import numpy as np


def classify(X, DATASET, LABELS, K):
    # 大寫引數表示常量,不可改變
    distances = np.sqrt(np.sum(np.square(DATASET - X), axis=1))
    # 計算距離矩陣
    len_dis = len(distances)
    # 得到distances數目
    labels = []
    # 儲存標籤
    for i in range(0, K):
        min_value = distances[i]
        min_value_idx = i
        for j in range(i + 1, len_dis):
            if distances[j] < min_value:
                min_value = distances[j]
                min_value_idx = j
        distances[i], distances[min_value_idx] = distances[min_value_idx], distances[i]
        labels.append(LABELS[min_value_idx])
    # 選擇排序挑選出前k個最值
    # 用labels儲存前k個最小距離的標籤
    C = labels[0]
    max_count = 0
    for label in labels:
        count = labels.count(label)
        if count > max_count:
            max_count = count
            C = label
    # 求前k個label中,重複次數最多的label,並返回
    return C

影象處理相關函式

#! /usr/bin/env python
# -*- coding:utf-8 -*-

"""
和影象操作有關函式
"""

import matplotlib.pyplot as plt
import numpy as np


def img2vector(filename):
    # 將影象轉向量
    vector = np.zeros([1024], int)
    # 定義返回的向量,大小為1*1024
    lines = None
    with open(filename, 'r') as f:
        lines = f.readlines()
    # 讀取32*32數字檔案
    for i in range(32):
        for j in range(32):
            vector[i * 32 + j] = lines[i][j]
    # 將資訊存放在vector中
    return vector


def img2mat(filename):
    # 將影象轉矩陣
    mat = np.zeros([32, 32], int)
    # 定義返回的矩陣,大小為32*32
    lines = None
    with open(filename, 'r') as f:
        lines = f.readlines()
    # 讀取32*32數字檔案
    for i in range(32):
        for j in range(32):
            mat[i, j] = lines[i][j]
    # 將資訊存放在mat中
    return mat


def show_img(mat):
    # 顯示影象
    plt.imshow(mat)
    # plt.axis('off')
    plt.show()

識別手寫數字

#! /usr/bin/env python
# -*- coding:utf-8 -*-

"""
基於knn演算法的手寫數字識別
"""

from os import listdir

import matplotlib.pyplot as plt
import numpy as np

from img import img2mat, img2vector, show_img
from knn import classify

train_digits_path = '/home/user/digits/trainingDigits/'
test_digits_path = '/home/user/digits/testDigits/'


def read_dataSet(path):
    file_list = listdir(path)
    # 獲取資料夾下的所有檔案路徑
    num_files = len(file_list)
    # 統計檔案數目
    dataset = np.zeros([num_files, 1024], int)
    # 用於存放所有的數字檔案
    labels = np.zeros([num_files])
    # 用於存放對應的標籤
    for i in range(num_files):
        # 遍歷所有的檔案
        file_path = file_list[i]
        # 獲取檔名稱
        digit = int(file_path.split('_')[0])
        # 通過檔名獲取標籤
        labels[i] = digit
        # 存放標籤
        dataset[i] = img2vector(path + '/' + file_path)
        # 存放資料
    return dataset, labels


# 讀取訓練集
train_dataset, train_labels = read_dataSet(train_digits_path)

# 讀取測試集
test_dataset, test_labels = read_dataSet(test_digits_path)


def classify_test_dataset(k):
    # 對測試集進行識別
    test_num = len(test_dataset)
    # 測試集的數目
    error_num = 0
    # 錯誤數目
    for data, label in zip(test_dataset, test_labels):
        res = classify(data, train_dataset, train_labels, k)
        # 對測試集進行預測
        if res != label:
            error_num += 1
        # 若預測錯誤,則計數器加一
    print("total:{},error num:{},error rate:{}".format(test_num, error_num, error_num / test_num))


def classify_test_data(filename):
    # 對測試集合,單個檔案進行識別
    file_path = test_digits_path + filename
    try:
        data = img2vector(file_path)
        res = classify(data, train_dataset, train_labels, 3)
        return int(res)
    except FileNotFoundError:
        print("No such file.")


if __name__ == '__main__':
    # 測試
    '''
    mats = []
    vs = []
    for idx in range(4):
        filename = str(idx) + '_2.txt'
        mat = img2mat(test_digits_path + filename)
        mats.append(mat)
        pv = classify_test_data(filename)
        rv = int(filename.split('_')[0])
        vs.append([pv,rv])

    for i in range(len(mats)):
        pv,rv = vs[i]
        plt.subplot(2,2,i+1)
        plt.xlabel("pv:"+str(pv)+",rv:"+str(rv))
        plt.imshow(mats[i])
    plt.show()    
    '''
    classify_test_dataset(3)

執行結果示例:



參考文獻

[1] Peter Harrington. Machine Learning in Action

[2] 維基百科 .K-nearest_neighbors_algorithm

[3] 百度百科 .k近鄰演算法

[4] New York University. data/digits