1. 程式人生 > >Python實戰之神經網路(1)

Python實戰之神經網路(1)



python有專門的神經網路庫,但為了加深印象,我自己在numpy庫的基礎上,自己編寫了一個簡單的神經網路程式,是基於Rosenblatt感知器的,這個感知器建立在一個線性神經元之上,神經元模型的求和節點計算作用於突觸輸入的線性組合,同時結合外部作用的偏置,對若干個突觸的輸入求和後進行調節。為了便於觀察,這裡的資料採用二維資料。

目標函式是訓練結果的誤差的平方和,由於目標函式是一個二次函式,只存在一個全域性極小值,所以採用梯度下降法的策略尋找目標函式的最小值。

程式碼如下:

[python] view plaincopyprint?在CODE上檢視程式碼片派生到我的程式碼片
  1. import numpy as np  
  2. import
     pylab as pl  
  3. b=1#偏置
  4. a=0.3#學習率
  5. x=np.array([[b,1,3],[b,2,3],[b,1,8],[b,2,15],[b,3,7],[b,4,29],[b,4,8],[b,4,20]])   #訓練資料
  6. d=np.array([1,1,-1,-1,1,-1,1,-1])                                      #訓練資料類別
  7. w=np.array([b,0,0])                                               #初始w
  8. def sgn(v):                                   
  9.     if v>=0:  
  10.         return1
  11.     else:  
  12.         return -1
  13. def comy(myw,myx):  
  14.     return sgn(np.dot(myw.T,myx))  
  15. def neww(oldw,myd,myx,a):  
  16.     return oldw+a*(myd-comy(oldw,myx))*myx  
  17. for ii in range(5):                                #迭代次數
  18.     i=0
  19.     for xn in x:  
  20.         w=neww(w,d[i],xn,a)  
  21.         i+=1
  22.     print w  
  23. myx=x[:,1]                                    #繪製訓練資料
  24. myy=x[:,2]  
  25. pl.subplot(111)  
  26. x_max=np.max(myx)+15
  27. x_min=np.min(myx)-5
  28. y_max=np.max(myy)+50
  29. y_min=np.min(myy)-5
  30. pl.xlabel(u"x")  
  31. pl.xlim(x_min,x_max)  
  32. pl.ylabel(u"y")  
  33. pl.ylim(y_min,y_max)  
  34. for i in range(0,len(d)):  
  35.     if d[i]==1:  
  36.         pl.plot(myx[i],myy[i],'r*')  
  37.     else:  
  38.         pl.plot(myx[i],myy[i],'ro')  
  39. #繪製測試點
  40. test=np.array([b,9,19])  
  41. if comy(w,test)>0:  
  42.     pl.plot(test[1],test[2],'b*')  
  43. else:  
  44.     pl.plot(test[1],test[2],'bo')  
  45. test=np.array([b,9,64])  
  46. if comy(w,test)>0:  
  47.     pl.plot(test[1],test[2],'b*')  
  48. else:  
  49.     pl.plot(test[1],test[2],'bo')  
  50. test=np.array([b,9,16])  
  51. if comy(w,test)>0:  
  52.     pl.plot(test[1],test[2],'b*')  
  53. else:  
  54.     pl.plot(test[1],test[2],'bo')  
  55. test=np.array([b,9,60])  
  56. if comy(w,test)>0:  
  57.     pl.plot(test[1],test[2],'b*')  
  58. else:  
  59.     pl.plot(test[1],test[2],'bo')  
  60. #繪製分類線
  61. testx=np.array(range(0,20))  
  62. testy=testx*2+1.68
  63. pl.plot(testx,testy,'g--')  
  64. pl.show()     
  65. for xn in x:  
  66.     print"%d  %d => %d" %(xn[1],xn[2],comy(w,xn))  
import numpy as np
import pylab as pl
b=1              #偏置
a=0.3            #學習率
x=np.array([[b,1,3],[b,2,3],[b,1,8],[b,2,15],[b,3,7],[b,4,29],[b,4,8],[b,4,20]])   #訓練資料
d=np.array([1,1,-1,-1,1,-1,1,-1])                                      #訓練資料類別
w=np.array([b,0,0])                                               #初始w
def sgn(v):                                 
    if v>=0:
        return 1
    else:
        return -1
def comy(myw,myx):
    return sgn(np.dot(myw.T,myx))
def neww(oldw,myd,myx,a):
    return oldw+a*(myd-comy(oldw,myx))*myx

for ii in range(5):                                #迭代次數
    i=0
    for xn in x:
        w=neww(w,d[i],xn,a)
        i+=1
    print w

myx=x[:,1]                                    #繪製訓練資料
myy=x[:,2]
pl.subplot(111)
x_max=np.max(myx)+15
x_min=np.min(myx)-5
y_max=np.max(myy)+50
y_min=np.min(myy)-5
pl.xlabel(u"x")
pl.xlim(x_min,x_max)
pl.ylabel(u"y")
pl.ylim(y_min,y_max)
for i in range(0,len(d)):
    if d[i]==1:
        pl.plot(myx[i],myy[i],'r*')
    else:
        pl.plot(myx[i],myy[i],'ro')
#繪製測試點
test=np.array([b,9,19])
if comy(w,test)>0:
    pl.plot(test[1],test[2],'b*')
else:
    pl.plot(test[1],test[2],'bo')
test=np.array([b,9,64])
if comy(w,test)>0:
    pl.plot(test[1],test[2],'b*')
else:
    pl.plot(test[1],test[2],'bo')
test=np.array([b,9,16])
if comy(w,test)>0:
    pl.plot(test[1],test[2],'b*')
else:
    pl.plot(test[1],test[2],'bo')
test=np.array([b,9,60])
if comy(w,test)>0:
    pl.plot(test[1],test[2],'b*')
else:
    pl.plot(test[1],test[2],'bo')
#繪製分類線
testx=np.array(range(0,20))
testy=testx*2+1.68
pl.plot(testx,testy,'g--')
pl.show()   
for xn in x:
    print "%d  %d => %d" %(xn[1],xn[2],comy(w,xn))


圖中紅色是訓練資料,藍色是測試資料,圓點代表類別-1.星點代表類別1。由圖可知,對於線性可分的資料集,Rosenblatt感知器的分類效果還是不錯的