1. 程式人生 > >資料探勘領域十大經典演算法之—K-鄰近演算法/kNN(超詳細附程式碼)

資料探勘領域十大經典演算法之—K-鄰近演算法/kNN(超詳細附程式碼)

簡介

又叫K-鄰近演算法,是監督學習中的一種分類演算法。目的是根據已知類別的樣本點集求出待分類的資料點類別。

基本思想

kNN的思想很簡單:在訓練集中選取離輸入的資料點最近的k個鄰居,根據這個k個鄰居中出現次數最多的類別(最大表決規則),作為該資料點的類別。kNN演算法中,所選擇的鄰居都是已經正確分類的物件。

e.g:下圖中,綠色圓要被決定賦予哪個類,是紅色三角形還是藍色四方形?如果k=3,由於紅色三角形所佔比例為2/3,綠色圓將被賦予紅色三角形那個類,如果k=5,由於藍色四方形比例為3/5,因此綠色圓被賦予藍色四方形類。
    image

演算法複雜度

kNN是一種lazy-learning演算法,分類器不需要使用訓練集進行訓練,因此訓練時間複雜度為0;kNN分類的計算複雜度和訓練集中的文件數目成正比,也就是說,如果訓練集中文件總數為n,那麼kNN的分類時間複雜度為O(n);因此,最終的時間複雜度是O(n)。

優缺點

優點

  1. 理論成熟,思想簡單,既可以用來做分類也可以用來做迴歸 ;
  2. 適合對稀有事件進行分類(例如:客戶流失預測);
  3. 特別適合於多分類問題(multi-modal,物件具有多個類別標籤,例如:根據基因特徵來判斷其功能分類), kNN比SVM的表現要好。

缺點

  1. 當樣本不平衡時,如一個類的樣本容量很大,而其他類樣本容量很小時,有可能導致當輸入一個新樣本時,該樣本的K個鄰居中大容量類的樣本佔多數;
  2. 計算量較大,因為對每一個待分類的文字都要計算它到全體已知樣本的距離,才能求得它的K個最近鄰點;
  3. 可理解性差,無法給出像決策樹那樣的規則。

程式碼

程式碼已在github

上實現,這裡也貼出來

# coding:utf-8

import numpy as np

def createDataset():
    '''
    建立訓練集,特徵值分別為搞笑鏡頭、擁抱鏡頭、打鬥鏡頭的數量
    '''
    learning_dataset = {"寶貝當家": [45, 2, 9, "喜劇片"],
              "美人魚": [21, 17, 5, "喜劇片"],
              "澳門風雲3": [54, 9, 11, "喜劇片"],
              "功夫熊貓3": [39, 0, 31, "喜劇片"],
              "諜影重重"
: [5, 2, 57, "動作片"], "葉問3": [3, 2, 65, "動作片"], "倫敦陷落": [2, 3, 55, "動作片"], "我的特工爺爺": [6, 4, 21, "動作片"], "奔愛": [7, 46, 4, "愛情片"], "夜孔雀": [9, 39, 8, "愛情片"], "代理情人": [9, 38, 2, "愛情片"], "新步步驚心": [8, 34, 17, "愛情片"]} return learning_dataset def kNN(learning_dataset,dataPoint,k): ''' kNN演算法,返回k個鄰居的類別和得到的測試資料的類別 ''' # s1:計算一個新樣本與資料集中所有資料的距離 disList=[] for key,v in learning_dataset.items(): d=np.linalg.norm(np.array(v[:3])-np.array(dataPoint)) disList.append([key,round(d,2)]) # s2:按照距離大小進行遞增排序 disList.sort(key=lambda dis: dis[1]) # s3:選取距離最小的k個樣本 disList=disList[:k] # s4:確定前k個樣本所在類別出現的頻率,並輸出出現頻率最高的類別 labels = {"喜劇片":0,"動作片":0,"愛情片":0} for s in disList: label = learning_dataset[s[0]] labels[label[len(label)-1]] += 1 labels =sorted(labels.items(),key=lambda asd: asd[1],reverse=True) return labels,labels[0][0] if __name__ == '__main__': learning_dataset=createDataset() testData={"唐人街探案": [23, 3, 17, "?片"]} dataPoint=list(testData.values())[0][:3] k=6 labels,result=kNN(learning_dataset,dataPoint,k) print(labels,result,sep='\n')