1. 程式人生 > >利用scikit-learn下的knn實現kaggle的手寫數字識別問題

利用scikit-learn下的knn實現kaggle的手寫數字識別問題

# -*- coding:utf-8 -*-
'''
Created on 2017年3月28日

@author: okcing

手寫數字識別
'''
import csv
from sklearn import neighbors

#匯入訓練資料和測試資料
def loadData(filename1,filename2,trainDataSet,trainTargetSet,testDataSet):
    with open(filename1,'r') as csvfile1:
        lines1 = csv.reader(csvfile1)
        dataSet = list(lines1)
        for
x in range(1,len(dataSet)): temp = [] dataSet[x][0] = int(dataSet[x][0]) trainTargetSet.append(dataSet[x][0]) for y in range(1,785): #if dataSet[x][y] != 0: #dataSet[x][y] = 1 dataSet[x][y] = int(dataSet[x][y]) temp.append(dataSet[x][y]) trainDataSet.append(temp) with
open(filename2,'r') as csvfile2: lines2 = csv.reader(csvfile2) dataSet2 = list(lines2) for x in range(1,len(dataSet2)): temp = [] for y in range(784): #if dataSet2[x][y] != 0: #dataSet2[x][y] = 1 dataSet2[x][y] = int(dataSet2[x][y]) temp.append(dataSet2[x][y]) testDataSet.append(temp) return
trainDataSet,trainTargetSet,testDataSet #將結果儲存為csv檔案用於在kaggle網站提交 def saveResult(result): #結果儲存的路徑 with open(r'D:\digit\result.csv','w',newline='') as myFile: myWriter=csv.writer(myFile) x=0 for i in result: x += 1 tmp=[x] tmp.append(i) myWriter.writerow(tmp) def main(): trainDataSet = [] trainTargetSet = [] testDataSet = [] print("開始載入資料") #訓練資料和測試資料的路徑 loadData(r'D:\digit\train.csv', r'D:\digit\test.csv', trainDataSet, trainTargetSet, testDataSet) knn = neighbors.KNeighborsClassifier() print("資料載入完畢,開始訓練模型") knn.fit(trainDataSet,trainTargetSet) print("模型訓練完畢,開始預測") prediction = knn.predict(testDataSet) print("預測結果:", prediction) print("列印完畢,開始儲存") saveResult(prediction) print("儲存完畢") if __name__ == '__main__': main()

整個實現十分簡單,將資料經過處理得到了trainData和trainTarget,用來訓練knn分類器,然後利用分類器對testSet進行預測,將結果儲存。利用sklearn的包進行機器學習確實很方便,對原始資料也不用進行什麼歸一化處理,也不必考慮用的是那種計算距離的方式,如果是剛開始入門,直接使用這個就能很方便的實現knn演算法。不過由於資料量巨大,所以程式執行耗時很長,我在自己的筆記本上大概跑了一個小時左右,載入資料大概5分鐘左右,主要是訓練模型很花時間。