kNN(k近鄰)算法代碼實現
阿新 • • 發佈:2019-03-20
通過 預測 3.5 得到 counter 代碼實現 code 統計 args
目標:預測未知數據(或測試數據)X的分類y
批量kNN算法
1.輸入一個待預測的X(一維或多維)給訓練數據集,計算出訓練集X_train中的每一個樣本與其的距離
2.找到前k個距離該數據最近的樣本-->所屬的分類y_train
3.將前k近的樣本進行統計,哪個分類多,則我們將x分類為哪個分類
# 準備階段: import numpy as np # import matplotlib.pyplot as plt raw_data_X = [[3.393533211, 2.331273381], [3.110073483, 1.781539638], [1.343808831, 3.368360954], [3.582294042, 4.679179110], [2.280362439, 2.866990263], [7.423436942, 4.696522875], [5.745051997, 3.533989803], [9.172168622, 2.511101045], [7.792783481, 3.424088941], [7.939820817, 0.791637231] ] raw_data_y= [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] X_train = np.array(raw_data_X) y_train = np.array(raw_data_y) x = np.array([8.093607318, 3.365731514])
核心代碼:
目標:預測未知數據(或測試數據)X的分類y 批量kNN算法 1.輸入一個待預測的X(一維或多維)給訓練數據集,計算出訓練集X_train中的每一個樣本與其的距離 2.找到前k個距離該數據最近的樣本-->所屬的分類y_train 3.將前k近的樣本進行統計,哪個分類多,則我們將x分類為哪個分類from math import sqrt from collections import Counter # 已知X_train,y_train # 預測x的分類 def predict(x, k=5): # 計算訓練集每個樣本與x的距離 distances = [sqrt(np.sum((x-x_train)**2)) for x_train in X_train] # 這裏用了numpy的fancy方法,np.sum((x-x_train)**2) # 獲得距離對應的索引,可以通過這些索引找到其所屬分類y_train nearest = np.argsort(distances) # 得到前k近的分類y topK_y = [y_train[neighbor] for neighbor in nearest[:k]] # 投票的方式,得到一個字典,key是分類,value數個數 votes = Counter(topK_y) # 取出得票第一名的分類 return votes.most_common(1)[0][0] # 得到y_predict predict(x, k=6)
面向對象的方式,模仿sklearn中的方法實現kNN算法:
import numpy as np from math import sqrt from collections import Counter class kNN_classify: def __init__(self, n_neighbor=5): self.k = n_neighbor self._X_train = None self._y_train = None def fit(self, X_train, y_train): self._X_train = X_train self._y_train = y_train return self def predict(self, X): ‘‘‘接收多維數據,返回y_predict也是多維的‘‘‘ y_predict = [self._predict(x) for x in X] # return y_predict return np.array(y_predict) # 返回array的格式 def _predict(self, x): ‘‘‘接收一個待預測的x,返回y_predict‘‘‘ distances = [sqrt(np.sum((x-x_train)**2)) for x_train in self._X_train] nearest = np.argsort(distances) topK_y = [self._y_train[neighbor] for neighbor in nearest[:self.k]] votes = Counter(topK_y) return votes.most_common(1)[0][0] def __repr__(self): return ‘kNN_clf(k=%d)‘ % self.k
kNN(k近鄰)算法代碼實現