機器學習實戰之k-近鄰演算法(3)---如何視覺化資料
阿新 • • 發佈:2019-01-06
關於視覺化:
《機器學習實戰》書中的一個小錯誤,P22的datingTestSet.txt這個檔案,根據網上的原始碼,應該選擇datingTestSet2.txt這個檔案。主要的區別是最後的標籤,作者原來使用字串‘veryLike’作為標籤,但是Python轉換會出現ValueError: invalid literal for int() with base 10: 'largeDoses'的錯誤。所以改成後面的檔案就可以了。後面直接用1 2 3 代表not like, general like, very like。這個錯誤一開始用百度查不到,改用Google(反向代理),第二個連結就是。哎。。。國內啊。。。
from numpy import * import operator #建立資料集 def createDataSet(): group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]]) labels = ['A', 'A', 'B', 'B'] return group, labels #根據輸入測試例項進行k-近鄰分類 def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0] diffMat = tile(inX, (dataSetSize, 1)) - dataSet sqDiffMat = diffMat ** 2 sqDistances = sqDiffMat.sum(axis=1) distances = sqDistances**0.5 sortedDistIndicies = distances.argsort() classCount = {} for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1 sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse=True) return sortedClassCount[0][0] #處理輸入格式問題,從檔案中讀取資料 def file2matrix(filename, dim2): fr = open(filename) arrayOLines = fr.readlines() numberOfLines = len(arrayOLines) returnMat = zeros((numberOfLines, dim2)) classLabelVector = [] index = 0 for line in arrayOLines: line = line.strip() listFromLine = line.split('\t') returnMat[index, :] = listFromLine[0:dim2] classLabelVector.append(int(listFromLine[-1])) index += 1 return returnMat, classLabelVector
>>>import kNN
>>> reload(kNN)
<module 'kNN' from 'kNN.pyc'>
>>> datingDataMat, datingLabels = kNN.file2matrix('datingTestSet2.txt', 3)
得到了約會網站的資料之後,我們可以可視化出來。
利用Matplotlib建立散點圖,python(x, y)內嵌這個,直接import即可。
>>> mimport matplotlib.pyplot as plt >>> fig = plt.figure() >>> ax = fig.add_subplot(111) >>> ax.scatter(datingDataMat[:,1], datingDataMat[:,2]) >>> plt.show()
這個是“玩視訊遊戲消耗時間比”和“每週所消費的冰淇淋公升數”的二維圖。
需要不同的顏色來得到更好的視覺化效果。
scatter函式的使用。
ax.scatter(datingDataMat[:,1], datingDataMat[:, 2], 15.0*array(datingLabels), 15.0*array(datingLabels))
這句程式碼替換前面的對應程式碼,如果出現name 'array' is not defined ,請在前面加這句:
from numpy import *
下面是青色,暗紅色,黑色三種點,是利用了15 * datingLabels的1, 2, 3作為不同點的顏色和尺寸。