1. 程式人生 > >機器學習算法的分類準確度

機器學習算法的分類準確度

dataset 數據 鞏固 digits pan () its dict style

本章,我們使用sklearn自帶的手寫識別的數據集進行計算準確度,進而鞏固之前學的KNN算法。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split


digits = datasets.load_digits()

x = digits.data #獲取特征值
y = digits.target #獲取標記

#將數據分為兩部分,訓練數據和測試數據

x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2)

#指定key值
knn = KNeighborsClassifier(n_neighbors=3)

#進行擬合
knn.fit(x_train,y_train)

y_predict = knn.predict(x_test)

ratio = sum(y_predict==y_test)/len(y_test)

print(ratio)

#當我們不想要預測值的時候,我們可以直接使用knn對象的score函數進行得出準確度
ratio_bak = knn.score(x_test,y_test)
print(ratio_bak)

本節主要是進行之前學的KNN算法進行鞏固,進而為後續的學習打好基礎。

機器學習算法的分類準確度