Implementing the k-Nearest Neighbors Algorithm in Python
阿新 • Published: 2019-02-13
1. Implementing the classifier
import operator

import numpy as np


def KNN(inx, dataset, labels, k):
    '''inx: the sample to classify; dataset: known samples (NumPy array);
    labels: class of each known sample; k: number of nearest points that vote'''
    # Distance computation (Euclidean distance)
    datasetsize = dataset.shape[0]
    diffmat = np.tile(inx, (datasetsize, 1)) - dataset
    sqdiffmat = diffmat**2
    sqdistances = sqdiffmat.sum(axis=1)
    distances = sqdistances**0.5
    # Select the k points with the smallest distances
    sortlist = distances.argsort()
    classcount = {}
    for i in range(k):
        votelabel = labels[sortlist[i]]
        # Build a histogram of votes per label
        classcount[votelabel] = classcount.get(votelabel, 0) + 1
    # Sort labels by vote count in descending order and return the winner
    sortclasscount = sorted(classcount.items(), key=operator.itemgetter(1), reverse=True)
    return sortclasscount[0][0]
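For comparison, here is a minimal sketch of the same majority-vote classifier that relies on NumPy broadcasting instead of `np.tile` and on `collections.Counter` instead of a hand-built dictionary. The name `knn_classify` and the four-point toy data set are illustrative assumptions, not part of the original article.

```python
from collections import Counter

import numpy as np


def knn_classify(inx, dataset, labels, k):
    """Return the majority label among the k training rows closest to inx.

    Hypothetical helper: a broadcasting-based variant of the KNN function above.
    """
    # Broadcasting subtracts inx from every row without building a tiled copy
    distances = np.sqrt(((dataset - inx) ** 2).sum(axis=1))
    # Indices of the k smallest distances
    nearest = distances.argsort()[:k]
    votes = Counter(labels[i] for i in nearest)
    # most_common(1) returns [(label, count)] for the top-voted label
    return votes.most_common(1)[0][0]


# Hypothetical toy data: two points near (1, 1) labelled 'A', two near (0, 0) labelled 'B'
group = np.array([[1.0, 1.1], [1.0, 1.0], [0.0, 0.0], [0.0, 0.1]])
group_labels = ['A', 'A', 'B', 'B']
print(knn_classify(np.array([0.0, 0.0]), group, group_labels, 3))  # prints B
```

The three nearest neighbours of (0, 0) are the two 'B' points and (1.0, 1.0), so 'B' wins two votes to one. With ties in vote count, `Counter.most_common` keeps the label that was counted first, i.e. the one with the closer neighbour.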
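2. File processing: read the tidy data into a NumPy feature matrix and a label list (data cleaning is not covered in this article):
def file2matrix(filename):
    # Each line holds three tab-separated feature values followed by an integer class label
    with open(filename) as fr:
        alines = fr.readlines()
    lines_num = len(alines)
    returnmat = np.zeros((lines_num, 3))
    labelsvector = []
    index = 0
    for line in alines:
        line_list = line.strip().split('\t')
        returnmat[index] = [float(x) for x in line_list[0:3]]
        labelsvector.append(int(line_list[-1]))
        index += 1
    return returnmat, labelsvector
3. Data normalization: rescale features that have different value ranges into the range 0 to 1 (features can also be weighted differently instead). The code below maps every feature uniformly to [0, 1] with (x - min) / (max - min):
def newvalue(dataset):
    minval = dataset.min(0)  # column-wise minimum
    maxval = dataset.max(0)  # column-wise maximum
    ranges = maxval - minval
    m = dataset.shape[0]
    newdataset = dataset - np.tile(minval, (m, 1))
    newdataset = newdataset / np.tile(ranges, (m, 1))
    return newdataset, ranges, minval
# returnmat is the feature matrix returned by file2matrix
newdataset, ranges, minval = newvalue(returnmat)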
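The parsing and normalization steps can also be sketched more compactly with `np.loadtxt` and broadcasting. This is a minimal sketch under stated assumptions: the helper names `load_matrix` and `min_max_normalize` and the three-row sample text are hypothetical, and the guard for constant columns is an addition not present in the original `newvalue`.

```python
import io

import numpy as np

# Hypothetical tab-separated sample: three numeric features, then an integer label
SAMPLE = "10\t2.0\t0.5\t1\n20\t4.0\t1.5\t2\n30\t3.0\t1.0\t1\n"


def load_matrix(text):
    """Parse tab-separated rows into (features, labels); a loadtxt-based sketch."""
    data = np.loadtxt(io.StringIO(text), delimiter='\t', ndmin=2)
    # First three columns are features; the last column is the class label
    return data[:, :3], data[:, -1].astype(int).tolist()


def min_max_normalize(dataset):
    """Scale each column to [0, 1] via (x - min) / (max - min)."""
    minval = dataset.min(axis=0)
    ranges = dataset.max(axis=0) - minval
    # Assumption: a constant column has range 0; divide by 1 instead to avoid NaN
    safe_ranges = np.where(ranges == 0, 1.0, ranges)
    return (dataset - minval) / safe_ranges, safe_ranges, minval


features, labels = load_matrix(SAMPLE)
normed, ranges, minval = min_max_normalize(features)
print(labels)  # prints [1, 2, 1]
print(normed)
```

Returning `ranges` and `minval` mirrors the original `newvalue`: a new sample must be scaled with the same values, `(x - minval) / ranges`, before it is passed to the classifier.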
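4. Testing the classifier. kNN is one of the simplest machine learning algorithms, but it depends heavily on how the training samples are preprocessed and its error rate can be relatively high, so the classifier should be evaluated on held-out data before its training samples are relied on for prediction: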
def datingClassTest():
    hoRatio = 0.10  # hold out 10% of the data for testing
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')  # load data set from file
    newdataset, ranges, minval = newvalue(datingDataMat)
    m = newdataset.shape[0]
    numTestVecs = int(m * hoRatio)
    errorCount = 0.0
    # The first numTestVecs rows are test samples; the remaining rows form the training set
    for i in range(numTestVecs):
        classifierResult = KNN(newdataset[i, :], newdataset[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    print("the total error rate is: %f" % (errorCount / float(numTestVecs)))
    print(errorCount)
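Because `datingTestSet2.txt` is not included here, the hold-out procedure can be exercised on synthetic data instead. This is a minimal, self-contained sketch: the two Gaussian clusters, the random seed, the shuffle step, and the helper names `knn_classify` and `holdout_error_rate` are all hypothetical assumptions, not the article's data or code. It also shows how a brand-new sample is scaled with the stored `minval` and `ranges` before it is classified.

```python
from collections import Counter

import numpy as np


def knn_classify(inx, dataset, labels, k):
    # Euclidean distance from inx to every training row (via broadcasting)
    distances = np.sqrt(((dataset - inx) ** 2).sum(axis=1))
    nearest = distances.argsort()[:k]
    return Counter(labels[i] for i in nearest).most_common(1)[0][0]


def holdout_error_rate(features, labels, ho_ratio=0.10, k=3):
    """Test on the first ho_ratio of rows and train on the rest, as datingClassTest does."""
    minval = features.min(axis=0)
    ranges = features.max(axis=0) - minval
    normed = (features - minval) / ranges
    m = normed.shape[0]
    num_test = int(m * ho_ratio)
    errors = sum(
        knn_classify(normed[i], normed[num_test:], labels[num_test:], k) != labels[i]
        for i in range(num_test)
    )
    return errors / num_test, minval, ranges


# Hypothetical synthetic data: two well-separated Gaussian clusters in 3-D
rng = np.random.default_rng(0)
class_one = rng.normal(loc=0.0, scale=1.0, size=(50, 3))
class_two = rng.normal(loc=10.0, scale=1.0, size=(50, 3))
features = np.vstack([class_one, class_two])
labels = [1] * 50 + [2] * 50
# Assumption: shuffle so the held-out first 10% contains both classes
order = rng.permutation(len(labels))
features = features[order]
labels = [labels[i] for i in order]

error_rate, minval, ranges = holdout_error_rate(features, labels)
print("the total error rate is: %f" % error_rate)

# A new sample must be scaled with the same minval and ranges before voting
new_sample = np.array([9.5, 10.2, 9.8])
normed_all = (features - minval) / ranges
print(knn_classify((new_sample - minval) / ranges, normed_all, labels, 3))
```

Like `datingClassTest`, this sketch computes the scaling from all rows, including the held-out ones. The shuffle matters only because the synthetic rows are generated in class order; the original test instead relies on the row order of the data file.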