
The k-Nearest Neighbor Classification Algorithm (kNN)

Note: parts of this article are adapted from Wikipedia.

In k-NN, the input consists of the k closest training examples in the feature space. The output depends on whether k-NN is used for classification or regression:

  • In k-NN classification, the output is a class membership. An object is classified by a majority vote of its neighbors, with the object being assigned to the class most common among its k nearest neighbors (k is a positive integer, typically small). If k = 1, then the object is simply assigned to the class of that single nearest neighbor.
  • In k-NN regression, the output is the property value for the object. This value is the average of the values of its k nearest neighbors.
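
To make the two modes concrete, here is a minimal sketch in plain NumPy; the function name knn_predict and its toy interface are illustrative only, not part of the code later in this post:

import numpy as np
from collections import Counter

def knn_predict(query, X, y, k, mode="classify"):
    # Euclidean distance from the query point to every training example
    dists = np.sqrt(((X - query) ** 2).sum(axis=1))
    nearest = np.argsort(dists)[:k]  # indices of the k closest examples
    if mode == "classify":
        return Counter(y[nearest]).most_common(1)[0][0]  # majority vote
    return y[nearest].mean()  # regression: average of the neighbors' values

# e.g. knn_predict(np.array([0., 0.]), np.array([[0., 0.1], [1., 1.]]),
#                  np.array(['B', 'A']), k=1) returns 'B'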

k-NN is a type of instance-based learning, or lazy learning, where the function is only approximated locally and all computation is deferred until classification. The k-NN algorithm is among the simplest of all machine learning algorithms.

Both for classification and regression, it can be useful to weight the contributions of the neighbors, so that the nearer neighbors contribute more to the average than the more distant ones. For example, a common weighting scheme consists in giving each neighbor a weight of 1/d, where d is the distance to the neighbor.
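
A minimal sketch of this 1/d weighting for k-NN regression, assuming the same NumPy setup as the sketch above; the small epsilon guard against zero distance is my addition, not from the text:

def weighted_knn_regress(query, X, y, k, eps=1e-8):
    dists = np.sqrt(((X - query) ** 2).sum(axis=1))
    nearest = np.argsort(dists)[:k]
    w = 1.0 / (dists[nearest] + eps)  # each neighbor weighted by 1/d
    return (w * y[nearest]).sum() / w.sum()  # distance-weighted average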

The neighbors are taken from a set of objects for which the class (for k-NN classification) or the object property value (for k-NN regression) is known. This can be thought of as the training set for the algorithm, though no explicit training step is required.

A shortcoming of the k-NN algorithm is that it is sensitive to the local structure of the data.

The training examples are vectors in a multidimensional feature space, each with a class label. The training phase of the algorithm consists only of storing the feature vectors and class labels of the training samples.

In the classification phase, k is a user-defined constant, and an unlabeled vector (a query or test point) is classified by assigning the label which is most frequent among the k training samples nearest to that query point.

A commonly used distance metric for continuous variables is Euclidean distance. For discrete variables, such as for text classification, another metric can be used, such as the overlap metric (or Hamming distance). Often, the classification accuracy of k-NN can be improved significantly if the distance metric is learned with specialized algorithms such as Large Margin Nearest Neighbor or Neighbourhood Components Analysis.
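
The overlap (Hamming) metric simply counts the positions at which two equal-length discrete vectors differ; a minimal sketch:

def hamming_distance(a, b):
    # number of positions at which the two sequences differ
    assert len(a) == len(b)
    return sum(x != y for x, y in zip(a, b))

# hamming_distance("karolin", "kathrin") == 3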

A drawback of the basic "majority voting" classification occurs when the class distribution is skewed. That is, examples of a more frequent class tend to dominate the prediction of the new example, because they tend to be common among the k nearest neighbors due to their large number. One way to overcome this problem is to weight the classification, taking into account the distance from the test point to each of its k nearest neighbors. The class (or value, in regression problems) of each of the k nearest points is multiplied by a weight proportional to the inverse of the distance from that point to the test point. Another way to overcome skew is by abstraction in data representation. For example, in a self-organizing map (SOM), each node is a representative (a center) of a cluster of similar points, regardless of their density in the original training data. k-NN can then be applied to the SOM.
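
A minimal sketch of the inverse-distance weighted vote just described, again assuming NumPy arrays as in the earlier sketches (the epsilon guard is my addition):

def weighted_vote_classify(query, X, y, k, eps=1e-8):
    dists = np.sqrt(((X - query) ** 2).sum(axis=1))
    nearest = np.argsort(dists)[:k]
    scores = {}
    for idx in nearest:
        # a close minority-class neighbor can outvote distant majority-class ones
        scores[y[idx]] = scores.get(y[idx], 0.0) + 1.0 / (dists[idx] + eps)
    return max(scores, key=scores.get)  # class with the largest total weight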

[Figure: a query point surrounded by red triangles and blue squares; with k=3 the vote goes to the red triangles, with k=5 to the blue squares.]

As the figure above shows, the query point in the center is assigned to the red triangles under 3-NN but to the blue squares under 5-NN. That is the basic idea of kNN. Its drawback is that every query point must have its distance computed to every data point, which is computationally expensive.

Below, we walk through the kNN algorithm with a Python implementation.

#coding:utf-8

from numpy import *
import operator
import os


# Create a small data set for development and testing
def createDataSet():
    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    labels = ['A','A','B','B']
    return group, labels

# Core kNN classification function
def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0] # total number of rows in dataSet
    # Compute the Euclidean distance between the input vector and every example in the data set
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    # Sort the distances; argsort returns the indices into the original array
    sortedDistIndices = distances.argsort()
    # Tally the class labels among the k nearest neighbors
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndices[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Convert the records in a file into a feature matrix and a label vector
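# Each line of the input file is expected to contain three tab-separated numeric
# features followed by an integer class label (the datingTestSet2.txt format used below).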
def file2matrix(filename):
    fr = open(filename, 'r')
    arrayOLines = fr.readlines()
    numberOfLines = len(arrayOLines)
    returnMat = zeros((numberOfLines,3))
    classLabelVector = []
    index = 0
    for line in arrayOLines:
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat, classLabelVector

# Preprocessing step: normalize each feature to the range [0, 1]
def autoNorm(dataSet):
    minVals = dataSet.min(0) # column-wise minima
    maxVals = dataSet.max(0) # column-wise maxima
    ranges = maxVals - minVals
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals, (m,1))
    normDataSet = normDataSet/tile(ranges,(m,1)) # element-wise: (value - min) / (max - min)
    return normDataSet, ranges, minVals
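
# Worked example (hypothetical values): for dataSet = array([[0.,10.],[5.,20.]]),
# minVals is [0.,10.], ranges is [5.,10.], and the normalized matrix is [[0.,0.],[1.,1.]],
# i.e. every feature value is rescaled as newValue = (oldValue - min) / (max - min).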

# Test harness for the dating-site example
def datingClassTest():
    hoRatio = 0.010 # hold out 1% of the data as the test set
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:], normMat[numTestVecs:m,:], datingLabels[numTestVecs:m], 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    print("the total error rate is: %f" % (errorCount/float(numTestVecs)))

# Interactive prediction function for the dating-site example
def classifyPerson():
    resultList = ['not at all', 'in small doses', 'in large doses']
    percentTats = float(input("percentage of time spent playing video games? "))
    ffMiles = float(input("frequent flier miles earned per year? "))
    iceCream = float(input("liters of ice cream consumed per year? "))
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    normMat, ranges, minVals = autoNorm(datingDataMat)
    inArr = array([ffMiles, percentTats, iceCream])
    classifierResult = classify0((inArr-minVals)/ranges, normMat, datingLabels, 3)
    print("you will probably like this person:", resultList[classifierResult - 1])
    

# Convert a 32x32 text image (32 lines of 32 '0'/'1' characters) into a 1x1024 vector
def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect


# Test harness for the handwritten-digit recognition example
def handwritingClassTest():
    hwLabels = []
    trainingFileList = os.listdir('trainingDigits')
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0]) # the digit label is encoded in the file name
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = os.listdir('testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount/float(mTest)))
    

# main
if __name__ == "__main__":
    # Test the kNN classifier on the small development data set
    group, labels = createDataSet()
    a = classify0([0,0], group, labels, 3)
    #print(a)

    # Test harness for the dating-site example
    datingClassTest()
    # Interactive prediction for the dating-site example
    classifyPerson()

    # Test harness for the handwritten-digit recognition example
    handwritingClassTest()