AI工程師成長之路-KNN分類演算法實現
阿新 • • 發佈:2018-12-15
本博文是博主在阿里雲大學學習記錄的筆記,未經博主允許禁止隨意轉載。
接下來將學習如何使用Python實現KNN分類演算法。
說明:本實驗的程式檔案與資料在啟動jupyter notebook後,就會在主目錄中顯示,可以直接開啟檢視並執行,但為了增加熟練度,達到最佳的學習效果,建議大家手動輸入。
①匯入程式執行所需的庫。
import numpy as np
import pandas as pd
②讀取鳶尾花資料集,並對資料集進行預處理。
# 讀取鳶尾花資料集,header引數來指定標題的行。預設為0。如果沒有標題,則使用None。 data = pd.read_csv(r"Iris.csv", header=0) # 顯示前n行記錄。預設n的值為5。 # data.head() # 顯示末尾的n行記錄。預設n的值為5。 #data.tail() # 隨機抽取樣本。預設抽取一條,我們可以通過引數進行指定抽取樣本的數量。 # data.sample(10) # 將類別文字對映成為數值型別。 data["Species"] = data["Species"].map({"Iris-virginica": 0, "Iris-setosa": 1, "Iris-versicolor": 2}) # 刪除不需要的Id列。 data.drop("Id", axis=1, inplace=True) # data.duplicated().any() # 檢視資料集的記錄數。 # len(data) # 刪除重複的記錄。 data.drop_duplicates(inplace=True) # len(data) # 檢視各個類別的鳶尾花具有多少條記錄。 data["Species"].value_counts()
③定義KNN類,用於分類。類中定義兩個預測方法,分為考慮權重與不考慮權重兩種情況。
class KNN: """使用Python語言實現K近鄰演算法。(實現分類)""" def __init__(self, k): """初始化方法 Parameters ----- k : int 鄰居的個數。 """ self.k = k def fit(self, X, y): """訓練方法 Parameters ----- X : 類陣列型別,形狀為:[樣本數量, 特徵數量] 待訓練的樣本特徵(屬性) y : 類陣列型別,形狀為: [樣本數量] 每個樣本的目標值(標籤)。 """ # 將X轉換成ndarray陣列型別。 self.X = np.asarray(X) self.y = np.asarray(y) def predict(self, X): """根據引數傳遞的樣本,對樣本資料進行預測。 Parameters ----- X : 類陣列型別,形狀為:[樣本數量, 特徵數量] 待訓練的樣本特徵(屬性) Returns ----- result : 陣列型別 預測的結果。 """ X = np.asarray(X) result = [] # 對ndarray陣列進行遍歷,每次取陣列中的一行。 for x in X: # 對於測試集中的每一個樣本,依次與訓練集中的所有樣本求距離。 dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1)) # 返回陣列排序後,每個元素在原陣列(排序之前的陣列)中的索引。 index = dis.argsort() # 進行截斷,只取前k個元素。【取距離最近的k個元素的索引】 index = index[:self.k] # 返回陣列中每個元素出現的次數。元素必須是非負的整數。 count = np.bincount(self.y[index]) # 返回ndarray陣列中,值最大的元素對應的索引。該索引就是我們判定的類別。 # 最大元素索引,就是出現次數最多的元素。 result.append(count.argmax()) return np.asarray(result) def predict2(self, X): """根據引數傳遞的樣本,對樣本資料進行預測(考慮權重的,使用距離的倒數作為權重)。 Parameters ----- X : 類陣列型別,形狀為:[樣本數量, 特徵數量] 待訓練的樣本特徵(屬性) Returns ----- result : 陣列型別 預測的結果。 """ X = np.asarray(X) result = [] # 對ndarray陣列進行遍歷,每次取陣列中的一行。 for x in X: # 對於測試集中的每一個樣本,依次與訓練集中的所有樣本求距離。 dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1)) # 返回陣列排序後,每個元素在原陣列(排序之前的陣列)中的索引。 index = dis.argsort() # 進行截斷,只取前k個元素。【取距離最近的k個元素的索引】 index = index[:self.k] # 返回陣列中每個元素出現的次數。元素必須是非負的整數。【使用weights考慮權重,權重為距離的倒數。】 count = np.bincount(self.y[index], weights=1 / dis[index]) # 返回ndarray陣列中,值最大的元素對應的索引。該索引就是我們判定的類別。 # 最大元素索引,就是出現次數最多的元素。 result.append(count.argmax()) return np.asarray(result)
④構建訓練集與測試集,用於對模型進行訓練與測試。
# 提取出每個類比的鳶尾花資料 t0 = data[data["Species"] == 0] t1 = data[data["Species"] == 1] t2 = data[data["Species"] == 2] # 對每個類別資料進行洗牌。 t0 = t0.sample(len(t0), random_state=0) t1 = t1.sample(len(t1), random_state=0) t2 = t2.sample(len(t2), random_state=0) # 構建訓練集與測試集。 train_X = pd.concat([t0.iloc[:40, :-1], t1.iloc[:40, :-1], t2.iloc[:40, :-1]], axis=0) train_y = pd.concat([t0.iloc[:40, -1], t1.iloc[:40, -1], t2.iloc[:40, -1]], axis=0) test_X = pd.concat([t0.iloc[40:, :-1], t1.iloc[40:, :-1], t2.iloc[40:, :-1]], axis=0) test_y = pd.concat([t0.iloc[40:, -1], t1.iloc[40:, -1], t2.iloc[40:, -1]], axis=0) # 建立KNN物件,進行訓練與測試。 knn = KNN(k=3) # 進行訓練 knn.fit(train_X, train_y) # 進行測試,獲得測試的結果。 result = knn.predict(test_X) # display(result) # display(test_y) display(np.sum(result == test_y)) display(np.sum(result == test_y)/ len(result))
程式執行結果如下:
26
0.9629629629629629
⑤在考慮權重的情況下,進行測試。
# 考慮權重,進行一下測試。
result2 = knn.predict2(test_X)
display(np.sum(result2 == test_y))
程式執行結果如下:
26
⑥匯入視覺化所必須的庫。
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams["font.family"] = "SimHei"
mpl.rcParams["axes.unicode_minus"] = False
⑦繪製散點圖。為了能夠更方便的進行視覺化,這裡只選擇了兩個維度(分別是花萼長度與花瓣長度)。
# {"Iris-virginica": 0, "Iris-setosa": 1, "Iris-versicolor": 2})
# 設定畫布的大小
plt.figure(figsize=(10, 10))
# 繪製訓練集資料
plt.scatter(x=t0["SepalLengthCm"][:40], y=t0["PetalLengthCm"][:40], color="r", label="Iris-virginica")
plt.scatter(x=t1["SepalLengthCm"][:40], y=t1["PetalLengthCm"][:40], color="g", label="Iris-setosa")
plt.scatter(x=t2["SepalLengthCm"][:40], y=t2["PetalLengthCm"][:40], color="b", label="Iris-versicolor")
# 繪製測試集資料
right = test_X[result == test_y]
wrong = test_X[result != test_y]
plt.scatter(x=right["SepalLengthCm"], y=right["PetalLengthCm"], color="c", marker="x", label="right")
plt.scatter(x=wrong["SepalLengthCm"], y=wrong["PetalLengthCm"], color="m", marker=">", label="wrong")
plt.xlabel("花萼長度")
plt.ylabel("花瓣長度")
plt.title("KNN分類結果顯示")
plt.legend(loc="best")
plt.show()
程式執行結果如下: