1. 程式人生 > 機器學習基本演算法:感知機

機器學習基本演算法:感知機

import numpy as np
import matplotlib.pyplot as plt

# initialization
n = 2      # number of features per sample
l = 100    # number of samples per class

# Class +1 has every coordinate in [0, 100) and class -1 has every coordinate
# in [-100, 0), so the two classes are linearly separable.
x1 = np.random.randint(0, 100, size=(l, n))
x2 = np.random.randint(-100, 0, size=(l, n))
x = np.vstack((x1, x2))                               # (2*l, n): positives first, then negatives
y = np.vstack((np.ones((l, 1)), -np.ones((l, 1))))    # (2*l, 1) labels aligned with the rows of x
w = np.full((n, 1), -1000.0)                          # initial weight column vector
b = 1                                                 # initial bias
base = np.arange(len(y))[:, None]   # row indices, used to look up misclassified points
learning_rate = 5

# SGD
# Perceptron criterion: a point is misclassified when y * (w.x + b) <= 0.
# Points lying exactly on the hyperplane count as errors; a strict `< 0` test
# would accept them as correctly classified.
while True:
    condition = (np.dot(x, w) + b) * y     # functional margin of every point
    flag = base[condition <= 0]            # indices of all misclassified points
    if flag.size == 0:
        break                              # every point is strictly on its own side
    idx = flag[0]                          # update using the first misclassified point
    w = w + learning_rate * y[idx] * x[idx][:, None]
    b = b + learning_rate * y[idx]         # b becomes a shape-(1,) array after the first update

# Figure 1: both classes plus the line w[0]*px1 + w[1]*px2 + b = 0 for the
# current w and b.
f1 = plt.figure(1)
pos, neg = x[:l], x[l:]          # first l rows are class +1, the rest class -1
p1 = plt.scatter(pos[:, 0], pos[:, 1], s=30, label='1', color='m', marker='x')
p2 = plt.scatter(neg[:, 0], neg[:, 1], s=15, label='-1', color='c', marker='o')
px1 = np.arange(-100, 100)
px2 = (-b - w[0] * px1) / w[1]   # solve the boundary equation for the second coordinate
p3 = plt.plot(px1, px2)


# dual form
# Start from a = 0, b = 0 so this run is independent of whatever value b held
# before this section. b is kept as a shape-(1,) array so it broadcasts
# against the (2*l, 1) score column.
a = np.zeros((2 * l, 1))    # one dual coefficient per training point
b = np.zeros(1)
gram = np.dot(x, x.T)       # Gram matrix of pairwise inner products (symmetric)

# A point is misclassified when y_i * (sum_j a_j * y_j * <x_j, x_i> + b) <= 0.
# The non-strict comparison is required here: with a = 0 and b = 0 every
# margin is exactly 0, and a strict `< 0` test would end training immediately.
while True:
    condition = y * (np.dot(gram, a * y) + b)
    flag = base[condition <= 0]            # indices of all misclassified points
    if flag.size == 0:
        break
    idx = flag[0]                          # update using the first misclassified point
    a[idx] += learning_rate
    b = b + learning_rate * y[idx]

# Recover the primal weights from the dual coefficients: w = sum_i a_i * y_i * x_i.
# a and y are (2*l, 1) columns, so broadcasting scales each row of x (one
# observation) by its own a_i * y_i.
x_re = a * y * x
w = x_re.sum(axis=0)          # 1-D array with one weight per feature

print(w)
print(b)

# Plot the scatter points and the line w.x + b = 0 for the current w and b
f2 = plt.figure(2)
p4 = plt.scatter(x[:l, 0], x[:l, 1], marker='x', color='m', label='1', s=30)
p5 = plt.scatter(x[l:, 0], x[l:, 1], marker='o', color='c', label='-1', s=15)
px4 = np.arange(-100, 100)
# Use this figure's own abscissa and ordinate (px4, px5); plotting px1/px2
# would redraw the line from figure 1 instead of the current boundary.
px5 = (-b - w[0] * px4) / w[1]
p6 = plt.plot(px4, px5)

plt.show()