1. 程式人生 > >機器學習演算法——K鄰近演算法

機器學習演算法——K鄰近演算法

#-*-coding=utf-8-*-
__author__ = 'whf'
from numpy import *
import operator
def classify (inx,dataSet,labels,k):
    #得到資料集的行數  shape方法用來得到矩陣或陣列的維數
    dataSetSize = dataSet.shape[0]
    #tile:numpy中的函式。tile將原來的一個數組,擴充成了dataSetSize行1列的陣列。diffMat得到了目標與訓練數值之間的差值。
    diffMat = tile(inx,(dataSetSize,1))-dataSet
    #計算差值的平方
    sqDiffMat = diffMat**2
    #計算差值平方和
    sqDistances = sqDiffMat.sum(axis = 1)
    #計算距離
    distances = sqDistances**0.5
    #得到排序後坐標的序號  argsort方法得到矩陣中每個元素的排序序號
    sortedDistIndicies = distances.argsort()
    classcount = {}
    for i in range(k):
        #找到前k個距離最近的座標的標籤
        voteIlabel = labels[sortedDistIndicies[i]]
        #在字典中設定鍵值對: 標籤:出現的次數
        classcount [voteIlabel] = classcount.get(voteIlabel,0)+1 #如果voteIlable標籤在classcount中就得到它的值加1否則就是0+1
    # 對字典中的類別出現次數進行排序,classCount中儲存的事 key-value,其中key就是label,value就是出現的次數
    # 所以key=operator.itemgetter(1)選中的是value,也就是對次數進行排序 reverse = True表示降序排列
    sortedClassCount = sorted(classcount.iteritems(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels = ['A','A','B','B']
print classify([0.1,0.1],group,labels,3)