機器學習之KNN最鄰近分類演算法
KNN演算法簡介
KNN(K-Nearest Neighbor)最鄰近分類演算法是資料探勘分類(classification)技術中最簡單的演算法之一,其指導思想是”近朱者赤,近墨者黑“,即由你的鄰居來推斷出你的類別。
KNN最鄰近分類演算法的實現原理:為了判斷未知樣本的類別,以所有已知類別的樣本作為參照,計算未知樣本與所有已知樣本的距離,從中選取與未知樣本距離最近的K個已知樣本,根據少數服從多數的投票法則(majority-voting),將未知樣本與K個最鄰近樣本中所屬類別佔比較多的歸為一類。
以上就是KNN演算法在分類任務中的基本原理,實際上K這個字母的含義就是要選取的最鄰近樣本例項的個數,在 scikit-learn 中 KNN演算法的 K 值是通過 n_neighbors 引數來調節的,預設值是 5。
如下圖所示,如何判斷綠色圓應該屬於哪一類,是屬於紅色三角形還是屬於藍色四方形?如果K=3,由於紅色三角形所佔比例為2/3,綠色圓將被判定為屬於紅色三角形那個類,如果K=5,由於藍色四方形比例為3/5,因此綠色圓將被判定為屬於藍色四方形類。
由於KNN最鄰近分類演算法在分類決策時只依據最鄰近的一個或者幾個樣本的類別來決定待分類樣本所屬的類別,而不是靠判別類域的方法來確定所屬類別的,因此對於類域的交叉或重疊較多的待分樣本集來說,KNN方法較其他方法更為適合。
KNN演算法的關鍵:
(1) 樣本的所有特徵都要做可比較的量化
若是樣本特徵中存在非數值的型別,必須採取手段將其量化為數值。例如樣本特徵中包含顏色,可通過將顏色轉換為灰度值來實現距離計算。
(2) 樣本特徵要做歸一化處理
樣本有多個引數,每一個引數都有自己的定義域和取值範圍,他們對距離計算的影響不一樣,如取值較大的影響力會蓋過取值較小的引數。所以樣本引數必須做一些 scale 處理,最簡單的方式就是所有特徵的數值都採取歸一化處置。
(3) 需要一個距離函式以計算兩個樣本之間的距離
通常使用的距離函式有:歐氏距離、餘弦距離、漢明距離、曼哈頓距離等,一般選歐氏距離作為距離度量,但是這是隻適用於連續變數。在文字分類這種非連續變數情況下,漢明距離可以用來作為度量。通常情況下,如果運用一些特殊的演算法來計算度量的話,K近鄰分類精度可顯著提高,如運用大邊緣最近鄰法或者近鄰成分分析法。
以計算二維空間中的A(x1,y1)、B(x2,y2)兩點之間的距離為例,歐氏距離和曼哈頓距離的計算方法如下圖所示:
(4) 確定K的值
K值選的太大易引起欠擬合,太小容易過擬合,需交叉驗證確定K值。
KNN演算法的優點:
1.簡單,易於理解,易於實現,無需估計引數,無需訓練;
2. 適合對稀有事件進行分類;
3.特別適合於多分類問題(multi-modal,物件具有多個類別標籤), kNN比SVM的表現要好。
KNN演算法的缺點:
KNN演算法在分類時有個主要的不足是,當樣本不平衡時,如一個類的樣本容量很大,而其他類樣本容量很小時,有可能導致當輸入一個新樣本時,該樣本的K個鄰居中大容量類的樣本佔多數,如下圖所示。該演算法只計算最近的鄰居樣本,某一類的樣本數量很大,那麼或者這類樣本並不接近目標樣本,或者這類樣本很靠近目標樣本。無論怎樣,數量並不能影響執行結果。可以採用權值的方法(和該樣本距離小的鄰居權值大)來改進。
該方法的另一個不足之處是計算量較大,因為對每一個待分類的文字都要計算它到全體已知樣本的距離,才能求得它的K個最近鄰點。
可理解性差,無法給出像決策樹那樣的規則。
KNN演算法實現
要自己動手實現KNN演算法其實不難,主要有以下三個步驟:
算距離:給定待分類樣本,計算它與已分類樣本中的每個樣本的距離;
找鄰居:圈定與待分類樣本距離最近的K個已分類樣本,作為待分類樣本的近鄰;
做分類:根據這K個近鄰中的大部分樣本所屬的類別來決定待分類樣本該屬於哪個分類;
以下是使用Python實現KNN演算法的簡單示例:
import math
import csv
import operator
import random
import numpy as np
from sklearn.datasets import make_blobs
#Python version 3.6.5
# 生成樣本資料集 samples(樣本數量) features(特徵向量的維度) centers(類別個數)
def createDataSet(samples=100, features=2, centers=2):
return make_blobs(n_samples=samples, n_features=features, centers=centers, cluster_std=1.0, random_state=8)
# 載入鳶尾花卉資料集 filename(資料集檔案存放路徑)
def loadIrisDataset(filename):
with open(filename, 'rt') as csvfile:
lines = csv.reader(csvfile)
dataset = list(lines)
for x in range(len(dataset)):
for y in range(4):
dataset[x][y] = float(dataset[x][y])
return dataset
# 拆分資料集 dataset(要拆分的資料集) split(訓練集所佔比例) trainingSet(訓練集) testSet(測試集)
def splitDataSet(dataSet, split, trainingSet=[], testSet=[]):
for x in range(len(dataSet)):
if random.random() <= split:
trainingSet.append(dataSet[x])
else:
testSet.append(dataSet[x])
# 計算歐氏距離
def euclideanDistance(instance1, instance2, length):
distance = 0
for x in range(length):
distance += pow((instance1[x] - instance2[x]), 2)
return math.sqrt(distance)
# 選取距離最近的K個例項
def getNeighbors(trainingSet, testInstance, k):
distances = []
length = len(testInstance) - 1
for x in range(len(trainingSet)):
dist = euclideanDistance(testInstance, trainingSet[x], length)
distances.append((trainingSet[x], dist))
distances.sort(key=operator.itemgetter(1))
neighbors = []
for x in range(k):
neighbors.append(distances[x][0])
return neighbors
# 獲取距離最近的K個例項中佔比例較大的分類
def getResponse(neighbors):
classVotes = {}
for x in range(len(neighbors)):
response = neighbors[x][-1]
if response in classVotes:
classVotes[response] += 1
else:
classVotes[response] = 1
sortedVotes = sorted(classVotes.items(), key=operator.itemgetter(1), reverse=True)
return sortedVotes[0][0]
# 計算準確率
def getAccuracy(testSet, predictions):
correct = 0
for x in range(len(testSet)):
if testSet[x][-1] == predictions[x]:
correct += 1
return (correct / float(len(testSet))) * 100.0
def main():
# 使用自定義建立的資料集進行分類
# x,y = createDataSet(features=2)
# dataSet= np.c_[x,y]
# 使用鳶尾花卉資料集進行分類
dataSet = loadIrisDataset(r'C:\DevTolls\eclipse-pureh2b\python\DeepLearning\KNN\iris_dataset.txt')
print(dataSet)
trainingSet = []
testSet = []
splitDataSet(dataSet, 0.75, trainingSet, testSet)
print('Train set:' + repr(len(trainingSet)))
print('Test set:' + repr(len(testSet)))
predictions = []
k = 7
for x in range(len(testSet)):
neighbors = getNeighbors(trainingSet, testSet[x], k)
result = getResponse(neighbors)
predictions.append(result)
print('>predicted=' + repr(result) + ',actual=' + repr(testSet[x][-1]))
accuracy = getAccuracy(testSet, predictions)
print('Accuracy: ' + repr(accuracy) + '%')
main()
尾花卉資料檔案百度網盤下載連結:https://pan.baidu.com/s/10vI5p_QuM7esc-jkar2zdQ 密碼:4und