使用k-近鄰演算法改進約會網站的配對效果--學習筆記(python3版本)
本文取自《機器學習實戰》第二章,原始為python2實現,現將程式碼移植到python3,且原始程式碼非常整潔,所以這本書的程式碼很值得學習一下。
k-近鄰演算法概述
工作原理:存在一個樣本資料集合,也稱作訓練樣本集,並且樣本集中每個資料都存在標籤,即我們知道樣本集中每一資料與所屬分類的對應關係。輸入沒有標籤的新資料後,將新資料的每個特徵與樣本集中資料對應的特徵進行比較,然後演算法提取樣本集中特徵最相似資料(最近鄰)的分類標籤。
k-近鄰演算法的一般流程
1.收集資料:可以使用任何方法
2.準備資料:距離計算所需要的數值,最好是結構化的資料格式
3.分析資料:可以使用任何方法
4.訓練演算法:此步驟不適於k-近鄰演算法
5.測試演算法:計算錯誤率
6.使用演算法:首先輸入樣本資料和結構化的輸出結果,然後執行k-近鄰演算法判定輸入資料分別屬於哪個分類,最後應用對計算出的分類執行後的處理
實現條件
我是在win7作業系統下實現的,使用pycharm。python3.6是用Anaconda。安裝包是numpy,Matplotlib。
演算法過程
程式1--k-近鄰演算法
def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0] #dataSet的行數 diffMat = tile(inX, (dataSetSize,1)) - dataSet #tile()在行上重複dataSetSize次,列上1次 sqDiffMat = diffMat ** 2 sqDistances = sqDiffMat.sum(axis=1)#每一行相加 distances = sqDistances ** 0.5 sortedDistIndicies = distances.argsort() #返回陣列值從小到大的索引 classCount = {} for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1 sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0]
函式classify0有四個輸入,inX代表用於待分類的輸入向量,dataSet代表訓練樣本集,labels代表標籤向量,k代表k-近鄰的值,一般k值選擇小於20.本文采用的是歐式距離,計算完待分類的向量和所有的點之間距離後,對資料排序,然後返回資料標籤最大的值。
程式2--將文字記錄轉換為Numpy的解析程式
def file2matrix(filename): fr = open(filename) arrayOLines = fr.readlines() numberOfLines = len(arrayOLines) returnMat = zeros((numberOfLines, 3)) classLabelVector = [] index = 0 for line in arrayOLines: line = line.strip() listFromLine = line.split('\t') returnMat[index, :] = listFromLine[0:3] classLabelVector.append(int(listFromLine[-1])) index += 1 return returnMat,classLabelVector
開啟文字檔案觀察,一共四列。最後一列為標籤,如下圖:
40920 |
8.326976 |
0.953952 |
3 |
14488 |
7.153469 |
0.953952 |
2 |
這屬於第二步準備資料。函式需要傳入一個引數,就是資料文字名字。首先開啟,然後一次讀取所有的行。計算出資料總共有多少行,構造一個和樣本資料行數相同,列為3的矩陣。構造標籤列表。然後逐行處理資料,並存入矩陣。這裡處理資料是先去掉空格,然後以\t分隔開。
分析資料
>>> import matplotlib
>>> import matplotlib.pyplot as plt
>>> fig = plt.figure()
>>> ax = fig.add_subplot(111)
>>> ax.scatter(group[:,0], group[:,1], 15*array(labels), array(labels))
>>> plt.show()
如圖:
這裡採用列1和列2的屬性值得到的圖,一般來說都是一次一次的試。這裡省略前面試的過程,直接採用最佳屬性成圖。
程式3--歸一化特徵值
def autoNorm(dataSet):
minVals = dataSet.min(0)#每列最小值
maxVals = dataSet.max(0)#每列最大值
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m,1))
normDataSet = normDataSet / tile(ranges, (m,1))
return normDataSet, ranges, minVals
在函式autoNorm中,將每列的最小值放入minVals中,最大值放入maxVals中,
歸一化公式:
newValue = (oldValue - min) / (max - min)
這裡採用的是線性函式歸一化,將原始資料歸一化到[0,1]之間。這樣消除特徵值之間的量綱的差距。比如在此文中,航程一般都是幾千,而消耗的冰淇淋之類一般也就是10以下,如果不歸一化處理,航程對分類結果的影響會非常的大,而歸一化之後,大家佔的比重都差不多。採用線性函式歸一化跟選擇的距離有關。這裡是歐式距離。還有0均值標準化歸一。採用不同的距離測量方法在具體考慮不同的歸一化方法。
程式4--分類器針對約會網站的測試程式碼
def datingClassTest():
hoRatio = 0.10
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m * hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m,:],\
datingLabels[numTestVecs:m],3)
print("the classifier came back with: %d,the real anser is: %d"\
% (classifierResult, datingLabels[i]))
if (classifierResult != datingLabels[i]):
errorCount += 1.0
print('the total error rate is : %f ' % (errorCount/float(numTestVecs)))
這裡其實是步驟5,測試演算法。也叫交叉驗證,一般用來評判分類器的效能。
函式datingClassTest()函式,先定義用於交叉驗證的資料比率。然後讀取資料樣本,再用autoNorm將資料樣本歸一化。在取得資料樣本的行數。在將具體要作為交叉驗證的資料樣本值存入numTestVecs中,這裡將資料樣本的前numTestVecs個樣本逐一讀取,然後運用k-近鄰演算法得到演算法判定的標籤,再跟真實標籤做比較。
一般來說交叉驗證的資料都是隨機取,若人為干預太多則會對分類器的效能判斷失誤。這裡還可以取最後的一段資料來判定。
程式5--約會網站預測函式
def classifyPerson():
resultList = ['not at all', 'in small doses', 'in large doses']
percentTats = float(input("percentage of time spent playing video games?"))
ffMiles = float(input("frequent flier miles earned per year?"))
iceCream = float(input("liters of ice cream consumed per year?"))
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = array([ffMiles, percentTats, iceCream])
classifierResult = classify0((inArr - minVals) / ranges, normMat, datingLabels, 3)
print("you will probably like this person: ", resultList[classifierResult - 1])
這段程式主要用來預測,做k-近鄰分類器不可能只是對已有的資料進行分類,最主要的就是用來預測沒有得到標籤的樣本資料。而預測結果的真實性,則由剛才的交叉驗證的結果來評估。如果剛才交叉驗證得到分類器的效能特別的差,那麼就需要調整分類演算法,或者觀察訓練樣本資料的特徵。
預測實驗結果如下圖:
總結
k-近鄰算是比較簡單好用的分類演算法了,也是這本書的第一個演算法。它具有
優點:精度高,對異常值不敏感,無資料輸入假定
缺點:計算複雜度高,空間複雜度高其實我們是否可以嘗試犧牲一定的精度來降低k-近鄰的計算複雜度。就比如說,此文章裡面,分成了三類。我將已有的資料樣本分別計算三類標籤的中心值,預測新的資料標籤時,我就計算新進來的資料與三個樣本中心值距離,就將資料劃分到離它最近的那個中心值那一組。這樣就變成了1-近鄰分類,這樣計算將會降低很多。如果需要提高精度,則將k值取大,但是k值取大之後計算量增加了無數倍。
k-近鄰演算法必須先對資料分類,然後才能預測。不像其他分類演算法,是先訓練樣本。k-近鄰學習起來簡單易懂。