2、python機器學習基礎教程——K近鄰演算法鳶尾花分類
阿新 • • 發佈:2019-01-04
一、第一個K近鄰演算法應用:鳶尾花分類
import numpy as np from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsClassifier # 載入資料 iris_dataset = load_iris() # 例項化模型 knn = KNeighborsClassifier(n_neighbors=1) # 切分訓練和測試資料集 X_train,X_test,y_train,y_test = train_test_split(iris_dataset["data"], iris_dataset["target"],random_state=0) #訓練 knn.fit(X_train, y_train) # 評估模型 print("Test set score:{:.2f}".format(knn.score(X_test,y_test))) # 預測 X_new = np.array([[5,2.9,1,0.2]]) prediction = knn.predict(X_new) print("Predicted target name:{}".format(iris_dataset["target_names"][prediction]))
以上程式碼段包含了應用scikit-learn中人和機器學習演算法的核心程式碼。
fit、predict和score方法是scikit-learn監督學習模型中最常用的介面。