1. 程式人生 > python簡單實現感知機

python簡單實現感知機

def f(x):
    """Return the perceptron score w[0]*x[0] + w[1]*x[1] + b.

    Reads the module-level globals ``w`` (weights) and ``b`` (bias).
    Only the first two coordinates of ``x`` are used.
    """
    score = w[0] * x[0]
    score += w[1] * x[1]
    score += b
    return score
def perception(data_set, max_iter=1000):
    """Train the global perceptron parameters ``w`` and ``b`` in place.

    Runs the primal perceptron algorithm with the module-level learning
    rate ``eta``. A sample ``[x, y]`` (label ``y`` in {+1, -1}) with
    ``y * (w . x + b) <= 0`` is misclassified and triggers the update
    ``w += eta * y * x``, ``b += eta * y``. Full passes over ``data_set``
    repeat until a pass makes no update.

    The passes are a loop rather than one recursive call per pass, so
    data that is not linearly separable stops after ``max_iter`` passes
    instead of exhausting the recursion limit.

    Args:
        data_set: list of ``[x, y]`` samples; ``x`` is a sequence of
            features of the same length as ``w``, ``y`` is the label.
        max_iter: maximum number of full passes over ``data_set``.

    Returns:
        True if a pass finished with no misclassified sample (converged),
        False if ``max_iter`` passes were used up without converging.
    """
    global b
    for _ in range(max_iter):
        converged = True
        for x, y in data_set:
            # Same score f() computes, generalized to any dimension.
            score = sum(wi * xi for wi, xi in zip(w, x)) + b
            if y * score <= 0:
                for i, xi in enumerate(x):
                    w[i] += eta * xi * y
                b += eta * y
                converged = False
        if converged:
            return True
    return False


data_set = [[[3, 3], 1], [[4, 3], 1], [[1, 1], -1]]
eta = 1
w = [0, 0]
b = 0
perception(data_set)
print(w, b)

參考學習一下別人的版本：

import copy
from matplotlib import pyplot as plt 
from matplotlib import animation
# Training data: each sample is [[x1, x2], y] with label y in {+1, -1}.
training_set = [
    [[3, 3], 1],
    [[4, 3], 1],
    [[1, 1], -1],
]
# Perceptron weights and bias, both starting at zero; update() changes them.
w, b = [0, 0], 0
# One [copy of w, b] snapshot per update, replayed frame by frame by animate().
history = []

def update(item):
    """Apply one perceptron update (learning rate 1) for a misclassified sample.

    ``item`` is ``[x, y]`` with label ``y`` in {+1, -1}. Performs
    ``w += y * x`` and ``b += y`` on the module-level parameters, prints
    the new values, and appends a ``[copy of w, b]`` snapshot to
    ``history`` for the animation.
    """
    global b
    x, y = item
    for i, xi in enumerate(x):
        w[i] += 1 * y * xi
    b += 1 * y
    print(w, b)
    # Copy w: it is mutated in place, so appending the list itself would
    # make every snapshot show the final weights.
    history.append([copy.copy(w), b])


def cal(item):
    """Return the functional margin ``y * (w . x + b)`` of sample ``item``.

    A value <= 0 means the sample is misclassified or on the boundary.
    """
    x, y = item
    res = 0
    for xi, wi in zip(x, w):
        res += xi * wi
    res += b
    return res * y


def check():
    """Run one pass over ``training_set``, updating on each misclassified sample.

    Returns True if any update was made (training should continue). Returns
    False once a full pass finds no misclassified sample, after printing
    the final parameters.
    """
    flag = False
    for item in training_set:
        if cal(item) <= 0:
            flag = True
            update(item)
    if not flag:
        print("RESULT:w:" + str(w) + "b:" + str(b))
    return flag


if __name__ == "__main__":
    # Train for at most 1000 passes; stop at the first pass with no update.
    for _ in range(1000):
        if not check():
            break

    fig = plt.figure()
    ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
    line, = ax.plot([], [], 'g', lw=2)
    # Start the label at a real coordinate (list coordinates cannot be
    # drawn); animate() moves it every frame.
    label = ax.text(0, 0, '')

    def init():
        """Plot the training points once and reset the boundary line."""
        line.set_data([], [])
        pos_x, pos_y, neg_x, neg_y = [], [], [], []
        for p in training_set:
            if p[1] > 0:
                pos_x.append(p[0][0])
                pos_y.append(p[0][1])
            else:
                neg_x.append(p[0][0])
                neg_y.append(p[0][1])
        plt.plot(pos_x, pos_y, 'bo', neg_x, neg_y, 'r^')
        plt.axis([-6, 6, -6, 6])
        plt.grid()
        plt.xlabel('x1')
        plt.ylabel('x2')
        plt.title('Perception')
        return line, label

    def animate(i):
        """Draw the boundary w . x + b = 0 for the i-th history snapshot."""
        w_i, b_i = history[i]
        if w_i[1] == 0:
            if w_i[0] == 0:
                # All-zero weights define no boundary: hide the line and label
                # instead of leaving the previous frame's boundary on screen.
                line.set_data([], [])
                label.set_text('')
                return line, label
            # w1 == 0: the boundary is the vertical line x1 = -b / w0.
            x0 = -b_i / w_i[0]
            line.set_data([x0, x0], [-7.0, 7.0])
            label.set_text(str(w_i) + ' ' + str(b_i))
            label.set_position([x0, 0.0])
            return line, label
        x1 = -7.0
        y1 = -(b_i + w_i[0] * x1) / w_i[1]
        x2 = 7.0
        y2 = -(b_i + w_i[0] * x2) / w_i[1]
        line.set_data([x1, x2], [y1, y2])
        label.set_text(str(w_i) + ' ' + str(b_i))
        # Put the label where the boundary crosses x1 = 0.
        label.set_position([0.0, -b_i / w_i[1]])
        return line, label

    # Keep a reference to the animation so it is not garbage-collected.
    # blit=True means only the parts that changed are redrawn.
    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=len(history), interval=1000,
                                   repeat=True, blit=True)
    plt.show()