1. 程式人生 > >機器學習實戰——改進約會網站匹配效果

機器學習實戰——改進約會網站匹配效果

接上文,改進約會網站的匹配效果,資料集有四列,分別為:飛行時間,玩遊戲時間,冰淇淋消費,是否為感興趣的約會物件。其中是否為感興趣的約會物件分為三類:不感興趣,有點感興趣和非常感興趣。

def file2matrix(filename):  #讀入文字記錄
    fr = open(filename)
    numberOfLines = len(fr.readlines())         #get the number of lines in the file
    returnMat = zeros((numberOfLines,3))        #prepare matrix to return
    classLabelVector = []                       #prepare labels return  
    fr = open(filename)
    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat,classLabelVector
   

 len(fr.readlines()) :獲取整個檔案有多少行

 zeros((numberOfLines,3))   :生成一個空的矩陣,內容都是0,這樣生成二維矩陣,可以明確有幾行幾列

returnMat[index,:]  :表示對returnMat中第index行所有元素按從頭到尾順序賦值,:前後都省略,表示從編號0項開始直到最後一位

 listFromLine[0:3]   :實際上是左閉右開區間,包括0但不包括3

.append :是list中不斷在末尾增加值的方法

這裡主要說明了python中讀檔案和將檔案內容轉化為矩陣

def autoNorm(dataSet):    #資料歸一化
    minVals = dataSet.min(0)   #取最小值
    maxVals = dataSet.max(0) #取最大值
    ranges = maxVals - minVals
    normDataSet = zeros(shape(dataSet)) #建一個和dataSet形狀相同的矩陣
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals, (m,1))
    normDataSet = normDataSet/tile(ranges, (m,1))   #newValue=(oldValue-min

)/(max-min)
    return normDataSet, ranges, minVals

normDataSet = zeros(shape(dataSet)) :建一個和dataSet形狀相同的矩陣,用0填充

這裡的歸一化,也是全部用矩陣處理,比起寫迴圈簡練很多

def datingClassTest():       #計算準確率
    hoRatio = 0.50      #hold out 10%
    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom file
    normMat, ranges, minVals = autoNorm(datingDataMat)  #歸一化
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
        if (classifierResult != datingLabels[i]): errorCount += 1.0   #計算預測錯誤的個數
    print "the total error rate is: %f" % (errorCount/float(numTestVecs))
    print errorCount