1. 程式人生 > >神經網路-手寫字型識別

神經網路-手寫字型識別

3層神經網路,自定義輸入節點、隱藏層、輸出節點的個數,使用sigmoid函式作為啟用函式,梯度下降法進行權重的優化。

使用MNIST資料集,進行手寫數字識別

  1 #!/usr/bin/env python
  2 # -*- coding:utf-8 -*-
  3 
  4 #!/usr/bin/env python
  5 # -*- coding:utf-8 -*-
  6 
  7 import numpy
  8 import scipy.special
  9 
 10 
 11 #手寫數字識別神經網路
 12 class NeuralNetwork():
 13     def __init__
(self,inputnodes,hiddennodes,outputnodes,learningrate): 14 ''' 15 神經網路初始化 16 :param inputnodes: 輸入節點的數量 17 :param hiddennodes: 隱藏層節點的數量 18 :param outputnodes: 輸出節點的數量 19 :param learningrate: 學習率 20 :return: 21 ''' 22 self.inodes = inputnodes
23 self.hnodes = hiddennodes 24 self.onodes = outputnodes 25 self.learn = learningrate 26 self.wih = numpy.random.rand(self.hnodes,self.inodes) - 0.5 27 self.who = numpy.random.rand(self.onodes,self.hnodes) - 0.5 28 # self.wih = numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.inodes,self.inodes))
29 # self.who = numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.hnodes,self.hnodes)) 30 self.activate_function = lambda x : scipy.special.expit(x) 31 # print(self.who) 32 # print(self.wih) 33 def train(self,input_list,target_list): 34 ''' 35 訓練神經網路首先計算樣本輸出,然後在與目標值進行對比,更新權重 36 :param input_list: 輸入值 37 :param target_list: 目標值 38 :return: 39 ''' 40 #針對樣本計算輸出,與query函式一樣 41 inputs = numpy.array(input_list).T 42 targets = numpy.array(target_list).T 43 hidden_inputs = numpy.dot(self.wih,inputs) 44 hidden_outputs = self.activate_function(hidden_inputs) 45 final_inputs = numpy.dot(self.who,hidden_outputs) 46 final_outpust = self.activate_function(final_inputs) 47 48 #將計算得到的輸出與目標值對比,更新權重 49 output_error = targets - final_outpust 50 hidden_error = numpy.dot(self.who.T,output_error) 51 52 # print(output_error.shape) 53 # print(final_outpust.shape) 54 # print(hidden_outputs.T.shape) 55 # self.who += self.learn*numpy.dot((output_error*final_outpust*(1.0-final_outpust)),numpy.transpose(hidden_outputs)) 56 # self.wih += self.learn*numpy.dot((hidden_error*hidden_outputs*(1.0-hidden_outputs)),numpy.transpose(inputs)) 57 58 self.who += self.learn*numpy.dot((output_error*final_outpust*(1.0-final_outpust)).reshape((self.onodes,1)),hidden_outputs.reshape((1,self.hnodes))) 59 self.wih += self.learn*numpy.dot((hidden_error*hidden_outputs*(1.0-hidden_outputs)).reshape((self.hnodes,1)),inputs.reshape((1,self.inodes))) 60 61 62 63 def query(self,input_list): 64 ''' 65 計算輸出 66 :param input_list: 67 :return: 68 ''' 69 inputs = numpy.array(input_list).T 70 hidden_inputs = numpy.dot(self.wih,inputs) 71 hidden_outputs = self.activate_function(hidden_inputs) 72 final_inputs = numpy.dot(self.who,hidden_outputs) 73 final_outpust = self.activate_function(final_inputs) 74 75 return final_outpust 76 77 #初始化一個神經網路物件 78 n = NeuralNetwork(784,100,10,0.5) 79 80 #訓練資料 81 with open('dataset/mnist_train.csv','r') as f: 82 train_data = f.readlines() 83 84 #訓練神經網路 85 for line in train_data: 86 data = line.split(',') 87 inputs = (numpy.asfarray(data[1:]) / 255 * 0.99) + 0.01 88 targets = numpy.zeros(n.onodes)+0.01 89 targets[int(data[0])] = 0.99 90 91 n.train(inputs,targets) 92 93 94 #測試神經網路 95 with open('dataset/mnist_test_10.csv','r') as f: 96 test_data = f.readlines() 97 98 for line in test_data: 99 label = int(line[0]) 100 data = line.split(',') 101 input_list = numpy.asfarray(data[1:]) 102 output = n.query(input_list) 103 104 print(label) 105 print(output)

程式碼實現了手寫數字的識別,可以在此基礎上,進行改進研究,比如調節學習率、初始化權重的方式,啟用函式等變化時對結果的影響。