1. 程式人生 > >2、python機器學習基礎教程——K近鄰演算法鳶尾花分類

2、python機器學習基礎教程——K近鄰演算法鳶尾花分類

一、第一個K近鄰演算法應用:鳶尾花分類

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# 載入資料
iris_dataset = load_iris()

# 例項化模型
knn = KNeighborsClassifier(n_neighbors=1)

# 切分訓練和測試資料集
X_train,X_test,y_train,y_test = train_test_split(iris_dataset["data"], iris_dataset["target"],random_state=0)

#訓練
knn.fit(X_train, y_train)

# 評估模型
print("Test set score:{:.2f}".format(knn.score(X_test,y_test)))

# 預測
X_new = np.array([[5,2.9,1,0.2]])
prediction = knn.predict(X_new)
print("Predicted target name:{}".format(iris_dataset["target_names"][prediction]))

以上程式碼段包含了應用scikit-learn中人和機器學習演算法的核心程式碼。

fit、predict和score方法是scikit-learn監督學習模型中最常用的介面。