1. 程式人生 > >[機器學習筆記]kNN進鄰演算法

[機器學習筆記]kNN進鄰演算法

K-近鄰演算法

一、演算法概述

(1)採用測量不同特徵值之間的距離方法進行分類

  • 優點: 精度高、對異常值不敏感、無資料輸入假定。
  • 缺點: 計算複雜度高、空間複雜度高。

(2)KNN模型的三個要素

kNN演算法模型實際上就是對特徵空間的的劃分。模型有三個基本要素:距離度量、K值的選擇和分類決策規則的決定。

  • 距離度量

    距離定義為:
    \[L_p(x_i,x_j)=\left( \sum^n_{l=1} |x_i^{(l)} - x_j^{(l)}|^p \right) ^{\frac{1}{p}}\]
    一般使用歐式距離:p = 2的個情況
    \[L_p(x_i,x_j)=\left( \sum^n_{l=1} |x_i^{(l)} - x_j^{(l)}|^2 \right) ^{\frac{1}{2}}\]

  • K值的選擇

    一般根據經驗選擇,需要多次選擇對比才可以選擇一個比較合適的K值。

    如果K值太小,會導致模型太複雜,容易產生過擬合現象,並且對噪聲點非常敏感。

    如果K值太大,模型太過簡單,忽略的大部分有用資訊,也是不可取的。

  • 分類決策規則

    一般採用多數表決規則,通俗點說就是在這K個類別中,哪種類別最後就判別為哪種型別


二、實施kNN演算法

2.1 虛擬碼

  • 計演算法已經類別資料集中的點與當前點之間的距離
  • 按照距離遞增次序排序
  • 選取與但前點距離最小的k個點
  • 確定前k個點所在類別的出現頻率
  • 返回前k個點出現頻率最高的類別作為當前點的預測分類


2.2 實際程式碼

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


三、實際案例:使用kNN演算法改進約會網站的配對效果

我的朋友阿J一直使用線上約會軟體尋找約會物件,他曾經交往過三種類型的人:

  • 不喜歡的人
  • 感覺一般的人
  • 非常喜歡的人

步驟:

  • 收集資料
  • 準備資料:也就是讀取資料的過程
  • 分析資料:使用Matplotlib畫出二維散點圖
  • 訓練演算法
  • 測試演算法
  • 使用演算法


3.1 準備資料

樣本資料共有1000個,3個特徵值,共有4列資料,最後一列表示標籤分類(0:不喜歡的人;1:感覺一般的人;2:非常喜歡的人)

特徵

  • 每年獲得的飛行常客里程數
  • 玩視訊遊戲所好的時間百分比
  • 每週消費的冰淇淋公斤數

部分資料如下:

40920   8.326976    0.953952    3
14488   7.153469    1.673904    2
26052   1.441871    0.805124    1
75136   13.147394   0.428964    1
38344   1.669788    0.134296    1
72993   10.141740   1.032955    1
35948   6.830792    1.213192    3
42666   13.276369   0.543880    3
67497   8.631577    0.749278    1
35483   12.273169   1.508053    3

讀取資料(讀取txt檔案)

def file2matrix(filename):
    fr = open(filename)
    numberOfLines = len(fr.readlines())         #get the number of lines in the file
    returnMat = zeros((numberOfLines,3))        #prepare matrix to return
    classLabelVector = []                       #prepare labels return   
    fr = open(filename)
    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat,classLabelVector


3.2 分析資料:使用Matplotlib建立散點圖

初步分析
import matplotlib
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']

fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
ax.set_xlabel("玩視訊遊戲所耗時間百分比")
ax.set_ylabel("每週消費的冰淇淋公斤數")
plt.show()


因為有三種類型的分類,這樣看的不直觀,我們新增以下顏色

fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
ax.scatter(datingDataMat[:,1], datingDataMat[:,2], 15.0*array(datingLabels), 15.0*array(datingLabels))
ax.set_xlabel("玩視訊遊戲所耗時間百分比")
ax.set_ylabel("每週消費的冰淇淋公斤數")
plt.show()


通過都多次的嘗試後發現,玩遊戲時間和冰淇淋這個兩個特徵關係比較明顯

具體的步驟:

  • 分別將標籤為1,2,3的三種類型的資料分開
  • 使用matplotlib繪製,並使用不同的顏色加以區分
datingDataType1 = array([[x[0][0],x[0][1],x[0][2]] for x in zip(datingDataMat,datingLabels) if x[1]==1])
datingDataType2 = array([[x[0][0],x[0][1],x[0][2]] for x in zip(datingDataMat,datingLabels) if x[1]==2])
datingDataType3 = array([[x[0][0],x[0][1],x[0][2]] for x in zip(datingDataMat,datingLabels) if x[1]==3])
                   

fig, axs = plt.subplots(2, 2, figsize = (15,10))
axs[0,0].scatter(datingDataType1[:,0], datingDataType1[:,1], s = 20, c = 'red')
axs[0,1].scatter(datingDataType2[:,0], datingDataType2[:,1], s = 30, c = 'green')
axs[1,0].scatter(datingDataType3[:,0], datingDataType3[:,1], s = 40, c = 'blue')
type1 = axs[1,1].scatter(datingDataType1[:,0], datingDataType1[:,1], s = 20, c = 'red')
type2 = axs[1,1].scatter(datingDataType2[:,0], datingDataType2[:,1], s = 30, c = 'green')
type3 = axs[1,1].scatter(datingDataType3[:,0], datingDataType3[:,1], s = 40, c = 'blue')
axs[1,1].legend([type1, type2, type3], ["Did Not Like", "Liked in Small Doses", "Liked in Large Doses"], loc=2)
axs[1,1].set_xlabel("玩視訊遊戲所耗時間百分比")
axs[1,1].set_ylabel("每週消費的冰淇淋公斤數")

plt.show()


3.3 準備資料:資料歸一化

通過上面的圖形繪製,發現三個特徵值的範圍不一樣,在使用KNN進行計算距離的時候,數值大的特徵值就會對結果產生更大的影響。

資料歸一化:就是將幾組不同範圍的資料,轉換到同一個範圍內。

公式: newValue = (oldValue - min)/(max - min)

def autoNorm(dataSet):
    minVals = dataSet.min(0) # array([[1,20,3], [4,5,60], [7,8,9]])   min(0) = [1, 5, 3]
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    normData = zeros(shape(dataSet))
    m = dataSet.shape[0]
    normData = (dataSet - tile(minVals, (m,1)))/tile(ranges,(m,1))
    return normData


3.4 測試演算法

我們將原始樣本保留20%作為測試集,剩餘80%作為訓練集

def datingClassTest():
    hoRatio = 0.20  
    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom file
    normMat = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:,:],datingLabels[numTestVecs:],3)
        if (classifierResult != datingLabels[i]): 
            errorCount += 1.0
    print ("the total error rate is: %f" % (errorCount/float(numTestVecs)))
    print (errorCount)

執行結果

the total error rate is: 0.080000
16.0


四、原始碼

from numpy import *
import operator
from os import listdir

import matplotlib
import matplotlib.pyplot as plt
    
## KNN function
def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# read txt data
def file2matrix(filename):
    fr = open(filename)
    numberOfLines = len(fr.readlines())         #get the number of lines in the file
    returnMat = zeros((numberOfLines,3))        #prepare matrix to return
    classLabelVector = []                       #prepare labels return   
    fr = open(filename)
    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat,classLabelVector


def autoNorm(dataSet):
    minVals = dataSet.min(0) # array([[1,20,3], [4,5,60], [7,8,9]])   min(0) = [1, 5, 3]
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    normData = zeros(shape(dataSet))
    m = dataSet.shape[0]
    normData = (dataSet - tile(minVals, (m,1)))/tile(ranges,(m,1))
    return normData
    
    
    
    
def drawScatter1(datingDataMat, datingLabels):
    plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
    ax.set_xlabel("玩視訊遊戲所耗時間百分比")
    ax.set_ylabel("每週消費的冰淇淋公斤數")
    plt.show()
    
def drawScatter2(datingDataMat, datingLabels):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
    ax.scatter(datingDataMat[:,1], datingDataMat[:,2], 15.0*array(datingLabels), 15.0*array(datingLabels))
    ax.set_xlabel("玩視訊遊戲所耗時間百分比")
    ax.set_ylabel("每週消費的冰淇淋公斤數")
    plt.show()
    
    
def drawScatter3(datingDataMat, datingLabels):
    datingDataType1 = array([[x[0][0],x[0][1],x[0][2]] for x in zip(datingDataMat,datingLabels) if x[1]==1])
    datingDataType2 = array([[x[0][0],x[0][1],x[0][2]] for x in zip(datingDataMat,datingLabels) if x[1]==2])
    datingDataType3 = array([[x[0][0],x[0][1],x[0][2]] for x in zip(datingDataMat,datingLabels) if x[1]==3])

    fig, axs = plt.subplots(2, 2, figsize = (15,10))
    axs[0,0].scatter(datingDataType1[:,0], datingDataType1[:,1], s = 20, c = 'red')
    axs[0,1].scatter(datingDataType2[:,0], datingDataType2[:,1], s = 30, c = 'green')
    axs[1,0].scatter(datingDataType3[:,0], datingDataType3[:,1], s = 40, c = 'blue')
    type1 = axs[1,1].scatter(datingDataType1[:,0], datingDataType1[:,1], s = 20, c = 'red')
    type2 = axs[1,1].scatter(datingDataType2[:,0], datingDataType2[:,1], s = 30, c = 'green')
    type3 = axs[1,1].scatter(datingDataType3[:,0], datingDataType3[:,1], s = 40, c = 'blue')
    axs[1,1].legend([type1, type2, type3], ["Did Not Like", "Liked in Small Doses", "Liked in Large Doses"], loc=2)
    axs[1,1].set_xlabel("玩視訊遊戲所耗時間百分比")
    axs[1,1].set_ylabel("每週消費的冰淇淋公斤數")

    plt.show()
    
    
    
def datingClassTest():
    hoRatio = 0.20  
    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom file
    normMat = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:,:],datingLabels[numTestVecs:],3)
        if (classifierResult != datingLabels[i]): 
            errorCount += 1.0
    print ("the total error rate is: %f" % (errorCount/float(numTestVecs)))
    print (errorCount)
    
    
datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")

drawScatter1(datingDataMat, datingLabels)
drawScatter2(datingDataMat, datingLabels)
drawScatter3(datingDataMat, datingLabels)
 
datingClassTest()