1. 程式人生 > >【機器學習】感知機Python程式碼實現

【機器學習】感知機Python程式碼實現

回顧

感知機

前面我們介紹了感知機,它是一個二分類的線性分類器,輸入為特徵向量,輸出為例項的類別。感知機演算法利用隨機梯度下降法對基於誤分類的損失函式進行最優化求解,得到感知機模型,即求解w,b。感知機演算法簡單易於實現,那麼我們如何通過python程式碼來實現呢?

接下來我們通過對我們給定的資料進行訓練,得到最終的w,b,並將其視覺化。

Python實現

import copy
from matplotlib import pyplot as plt
from matplotlib import animation

training_set = [[(1, 2
), 1], [(2, 3), 1], [(3, 1), -1], [(4, 2), -1]] # 訓練資料集 w = [0, 0] # 引數初始化 b = 0 history = [] # 用來記錄每次更新過後的w,b def update(item): """ 隨機梯度下降更新引數 :param item: 引數是分類錯誤的點 :return: nothing 無返回值 """ global w, b, history # 把w, b, history宣告為全域性變數 w[0] += 1 * item[1] * item[0][0] # 根據誤分類點更新引數,這裡學習效率設為1
w[1] += 1 * item[1] * item[0][1] b += 1 * item[1] history.append([copy.copy(w), b]) # 將每次更新過後的w,b記錄在history陣列中 def cal(item): """ 計算item到超平面的距離,輸出yi(w*xi+b) (我們要根據這個結果來判斷一個點是否被分類錯了。如果yi(w*xi+b)>0,則分類錯了) :param item: :return: """ res = 0 for i in range(len(item[0
])): # 迭代item的每個座標,對於本文資料則有兩個座標x1和x2 res += item[0][i] * w[i] res += b res *= item[1] # 這裡是乘以公式中的yi return res def check(): """ 檢查超平面是否已將樣本正確分類 :return: true如果已正確分類則返回True """ flag = False for item in training_set: if cal(item) <= 0: # 如果有分類錯誤的 flag = True # 將flag設為True update(item) # 用誤分類點更新引數 if not flag: # 如果沒有分類錯誤的點了 print("最終結果: w: " + str(w) + "b: " + str(b)) # 輸出達到正確結果時引數的值 return flag # 如果已正確分類則返回True,否則返回False if __name__ == "__main__": for i in range(1000): # 迭代1000遍 if not check(): break # 如果已正確分類,則結束迭代 # 以下程式碼是將迭代過程視覺化 # 首先建立我們想要做成動畫的影象figure, 座標軸axis,和plot element fig = plt.figure() ax = plt.axes(xlim=(0, 2), ylim=(-2, 2)) line, = ax.plot([], [], 'g', lw=2) # 畫一條線 label = ax.text([], [], '') def init(): line.set_data([], []) x, y, x_, y_ = [], [], [], [] for p in training_set: if p[1] > 0: x.append(p[0][0]) # 存放yi=1的點的x1座標 y.append(p[0][1]) # 存放yi=1的點的x2座標 else: x_.append(p[0][0]) # 存放yi=-1的點的x1座標 y_.append(p[0][1]) # 存放yi=-1的點的x2座標 plt.plot(x, y, 'bo', x_, y_, 'rx') # 在圖裡yi=1的點用點表示,yi=-1的點用叉表示 plt.axis([-6, 6, -6, 6]) # 橫縱座標上下限 plt.grid(True) # 顯示網格 plt.xlabel('x1') # 這裡我修改了原文表示 plt.ylabel('x2') # 為了和原理中表達方式一致,橫縱座標應該是x1,x2 plt.title('Perceptron Algorithm (www.hankcs.com)') # 給圖一個標題:感知機演算法 return line, label def animate(i): global history, ax, line, label w = history[i][0] b = history[i][1] if w[1] == 0: return line, label # 因為圖中座標上下限為-6~6,所以我們在橫座標為-7和7的兩個點之間畫一條線就夠了,這裡程式碼中的xi,yi其實是原理中的x1,x2 x1 = -7 y1 = -(b + w[0] * x1) / w[1] x2 = 7 y2 = -(b + w[0] * x2) / w[1] line.set_data([x1, x2], [y1, y2]) # 設定線的兩個點 x1 = 0 y1 = -(b + w[0] * x1) / w[1] label.set_text(history[i]) label.set_position([x1, y1]) return line, label print("引數w,b更新過程:", history) anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(history), interval=1000, repeat=True, blit=True) plt.show()

執行結果

最終結果: w: [-3, 4]b: 1
引數w,b更新過程: [[[1, 2], 1], [[-2, 1], 0], [[-1, 3], 1], [[-4, 2], 0], [[-3, 4], 1]]

這裡寫圖片描述