1. 程式人生 > >python線上神經網路實現手寫字元識別系統

python線上神經網路實現手寫字元識別系統

               

神經網路實現手寫字元識別系統

一、課程介紹

1. 課程來源

課程內容在原文件基礎上做了稍許修改,增加了部分原理介紹,步驟的拆解分析及原始碼註釋。

2. 內容簡介

本課程最終將基於BP神經網路實現一個手寫字元識別系統,系統會在伺服器啟動時自動讀入訓練好的神經網路檔案,如果檔案不存在,則讀入資料集開始訓練,使用者可以通過在html頁面上手寫數字傳送給伺服器來得到識別結果。

3. 課程知識點

本課程專案完成過程中,我們將學習:

  1. 什麼是神經網路
  2. 在客戶端(瀏覽器)完成手寫資料的輸入與請求的傳送
  3. 在伺服器端根據請求呼叫神經網路模組並給出響應
  4. 實現BP神經網路

二、實驗環境

開啟終端,進入 Code 目錄,建立 ocr 資料夾, 並將其作為我們的工作目錄。

$ cd Code$ mkdir ocr && cd ocr

三、實驗原理

人工智慧

圖靈對於人工智慧的定義大家都已耳熟能詳,但"是什麼構成了智慧"至今仍是一個帶有爭論的話題。電腦科學家們目前將人工智慧分成了多個分支,每一個分支都專注於解決一個特定的問題領域,舉其中三個有代表性的分支:

  • 基於預定義知識的邏輯與概率推理,比如模糊推理能夠幫助一個恆溫器根據監測到的溫度和溼度決定什麼時候開關空調。
  • 啟發式搜尋,比如在棋類遊戲中搜索到走下一子的最優解。
  • 機器學習,比如手寫字元識別系統。

簡單來說,機器學習的目的就是通過大量資料訓練一個能夠識別一種或多種模式的系統。訓練系統用的資料集合被稱作訓練集,如果訓練集的每個資料條目都打上目標輸出值(也就是標籤),則該方法稱作監督學習,不打標籤的則是非監督學習。機器學習中有多種演算法能夠實現手寫字元識別系統,在本課程中我們將基於神經網路實現該系統。

什麼是神經網路

神經網路由能夠互相通訊的節點構成,赫布理論解釋了人體的神經網路是如何通過改變自身的結構和神經連線的強度來記憶某種模式的。而人工智慧中的神經網路與此類似。請看下圖,最左一列藍色節點是輸入節點,最右列節點是輸出節點,中間節點是隱藏節點。該圖結構是分層的,隱藏的部分有時候也會分為多個隱藏層。如果使用的層數非常多就會變成我們平常說的深度學習了。

此處輸入圖片的描述

每一層(除了輸入層)的節點由前一層的節點加權加相加加偏置向量並經過啟用函式得到,公式如下:

此處輸入圖片的描述

其中f是啟用函式,b是偏置向量,它們的作用會在之後說明。

這一類拓撲結構的神經網路稱作前饋神經網路,因為該結構中不存在迴路。有輸出反饋給輸入的神經網路稱作遞迴神經網路(RNN)。在本課程中我們使用前饋神經網路中經典的BP神經網路來實現手寫識別系統。

如何使用神經網路

很簡單,神經網路屬於監督學習,那麼多半就三件事,決定模型引數,通過資料集訓練學習,訓練好後就能到分類工具/識別系統用了。資料集可以分為2部分(訓練集,驗證集),也可以分為3部分(訓練集,驗證集,測試集),訓練集可以看作平時做的習題集(可反覆做),系統通過對比習題集的正確答案和自己的解答來不斷學習改良自己。測試集可以看作是高考,同一份試卷只能考一次,測試集一般不會透露答案。那麼驗證集是什麼呢?好比多個學生(類比用不同策略訓練出的多個神經網路)要參加一個名額只有兩三人的比賽,那麼就得給他們一套他們沒做過的卷子(驗證集)來逐出成績最好的幾個人,有時也使用驗證集決定模型引數。在本課程中資料集只劃分訓練集和驗證集。

系統構成

我們的OCR系統分為5部分,分別寫在5個檔案中:

  • 客戶端(ocr.js
  • 伺服器(server.py
  • 使用者介面(ocr.html
  • 神經網路(ocr.py)
  • 神經網路設計指令碼(neural_network_design.py)

使用者介面(ocr.html)是一個html頁面,使用者在canvans上寫數字,之後點選選擇訓練或是預測。客戶端(ocr.js)將收集到的手寫數字組合成一個數組傳送給伺服器端(server.py)處理,伺服器呼叫神經網路模組(ocr.py),它會在初始化時通過已有的資料集訓練一個神經網路,神經網路的資訊會被儲存在檔案中,等之後再一次啟動時使用。最後,神經網路設計指令碼(neural_network_design.py)是用來測試不同隱藏節點數下的效能,決定隱藏節點數用的。

四、實驗步驟

我們將根據系統構成的五部分一一實現,在講解完每一部分的核心程式碼後給出完整的檔案程式碼。

實現使用者介面

需要給予使用者輸入資料、預測、訓練的介面,這部分較簡單,所以直接給出完整程式碼:

<!-- index.html --><!DOCTYPE html><html><head>    <scriptsrc="ocr.js"></script></head><bodyonload="ocrDemo.onLoadFunction()">    <divid="main-container"style="text-align: center;">        <h1>OCR Demo</h1>        <canvasid="canvas"width="200"height="200"></canvas>        <formname="input">            <p>Digit: <inputid="digit"type="text"> </p>            <inputtype="button"value="Train"onclick="ocrDemo.train()">            <inputtype="button"value="Test"onclick="ocrDemo.test()">            <inputtype="button"value="Reset"onclick="ocrDemo.resetCanvas();"/>        </form>     </div></body></html>

開一個伺服器看一下頁面效果:

python -m SimpleHTTPServer 3000

開啟瀏覽器位址列輸入localhost:3000

頁面效果如下圖:

此處輸入圖片的描述

手寫輸入等主要的客戶端邏輯需要在ocr.js檔案中實現。

實現客服端

畫布設定了200*200,但我們並不需要200*200這麼精確的輸入資料,20*20就很合適。

var ocrDemo = {    CANVAS_WIDTH: 200,    TRANSLATED_WIDTH: 20,    PIXEL_WIDTH: 10, // TRANSLATED_WIDTH = CANVAS_WIDTH / PIXEL_WIDTH

在畫布上加上網格輔助輸入和檢視:

    drawGrid: function(ctx) {        for (var x = this.PIXEL_WIDTH, y = this.PIXEL_WIDTH;                  x < this.CANVAS_WIDTH; x += this.PIXEL_WIDTH,                  y += this.PIXEL_WIDTH) {            ctx.strokeStyle = this.BLUE;            ctx.beginPath();            ctx.moveTo(x, 0);            ctx.lineTo(x, this.CANVAS_WIDTH);            ctx.stroke();            ctx.beginPath();            ctx.moveTo(0, y);            ctx.lineTo(this.CANVAS_WIDTH, y);            ctx.stroke();        }    },

我們使用一維陣列來儲存手寫輸入,0代表黑色(背景色),1代表白色(筆刷色)。

手寫輸入與儲存的程式碼:

    onMouseMove: function(e, ctx, canvas) {        if (!canvas.isDrawing) {            return;        }        this.fillSquare(ctx,             e.clientX - canvas.offsetLeft, e.clientY - canvas.offsetTop);    },    onMouseDown: function(e, ctx, canvas) {        canvas.isDrawing = true;        this.fillSquare(ctx,             e.clientX - canvas.offsetLeft, e.clientY - canvas.offsetTop);    },    onMouseUp: function(e) {        canvas.isDrawing = false;    },    fillSquare: function(ctx, x, y) {        var xPixel = Math.floor(x / this.PIXEL_WIDTH);        var yPixel = Math.floor(y / this.PIXEL_WIDTH);        //在這裡儲存輸入        this.data[((xPixel - 1)  * this.TRANSLATED_WIDTH + yPixel) - 1] = 1;        ctx.fillStyle = '#ffffff'; //白色        ctx.fillRect(xPixel * this.PIXEL_WIDTH, yPixel * this.PIXEL_WIDTH,             this.PIXEL_WIDTH, this.PIXEL_WIDTH);    },

下面完成在客戶端點選訓練鍵時觸發的函式。

當客戶端的訓練資料到達一定數量時,就一次性傳給伺服器端給神經網路訓練用:

    train: function() {        var digitVal = document.getElementById("digit").value;        // 如果沒有輸入標籤或者沒有手寫輸入就報錯        if (!digitVal || this.data.indexOf(1) < 0) {            alert("Please type and draw a digit value in order to train the network");            return;        }        // 將訓練資料加到客戶端訓練集中        this.trainArray.push({"y0": this.data, "label": parseInt(digitVal)});        this.trainingRequestCount++;        // 訓練資料到達指定的量時就傳送給伺服器端        if (this.trainingRequestCount == this.BATCH_SIZE) {            alert("Sending training data to server...");            var json = {                trainArray: this.trainArray,                train: true            };            this.sendData(json);            // 清空客戶端訓練集            this.trainingRequestCount = 0;            this.trainArray = [];        }    },

為什麼要設定BATCH_SIZE呢?這是為了防止伺服器在短時間內處理過多請求而降低了伺服器的效能。

接著完成在客戶端點選測試鍵(也就是預測)時觸發的函式:

    test: function() {        if (this.data.indexOf(1) < 0) {            alert("Please draw a digit in order to test the network");            return;        }        var json = {            image: this.data,            predict: true        };        this.sendData(json);    },

最後,我們需要處理在客戶端接收到的響應,這裡只需處理預測結果的響應:

    receiveResponse: function(xmlHttp) {        if (xmlHttp.status != 200) {            alert("Server returned status " + xmlHttp.status);            return;        }        var responseJSON = JSON.parse(xmlHttp.responseText);        if (xmlHttp.responseText && responseJSON.type == "test") {            alert("The neural network predicts you wrote a \'"                    + responseJSON.result + '\'');        }    },    onError: function(e) {        alert("Error occurred while connecting to server: " + e.target.statusText);    },

ocr.js的完整程式碼如下:

var ocrDemo = {    CANVAS_WIDTH: 200,    TRANSLATED_WIDTH: 20,    PIXEL_WIDTH: 10, // TRANSLATED_WIDTH = CANVAS_WIDTH / PIXEL_WIDTH    BATCH_SIZE: 1,    // 伺服器端引數    PORT: "9000",    HOST: "http://localhost",    // 顏色變數    BLACK: "#000000",    BLUE: "#0000ff",    // 客戶端訓練資料集    trainArray: [],    trainingRequestCount: 0,    onLoadFunction: function() {        this.resetCanvas();    },    resetCanvas: function() {        var canvas = document.getElementById('canvas');        var ctx = canvas.getContext('2d');        this.data = [];        ctx.fillStyle = this.BLACK;        ctx.fillRect(0, 0, this.CANVAS_WIDTH, this.CANVAS_WIDTH);        var matrixSize = 400;        while (matrixSize--) this.data.push(0);        this.drawGrid(ctx);        // 繫結事件操作        canvas.onmousemove = function(e) { this.onMouseMove(e, ctx, canvas) }.bind(this);        canvas.onmousedown = function(e) { this.onMouseDown(e, ctx, canvas) }.bind(this);        canvas.onmouseup = function(e) { this.onMouseUp(e, ctx) }.bind(this);    },    drawGrid: function(ctx) {        for (var x = this.PIXEL_WIDTH, y = this.PIXEL_WIDTH; x < this.CANVAS_WIDTH; x += this.PIXEL_WIDTH, y += this.PIXEL_WIDTH) {            ctx.strokeStyle = this.BLUE;            ctx.beginPath();            ctx.moveTo(x, 0);            ctx.lineTo(x, this.CANVAS_WIDTH);            ctx.stroke();            ctx.beginPath();            ctx.moveTo(0, y);            ctx.lineTo(this.CANVAS_WIDTH, y);            ctx.stroke();        }    },    onMouseMove: function(e, ctx, canvas) {        if (!canvas.isDrawing) {            return;        }        this.fillSquare(ctx, e.clientX - canvas.offsetLeft, e.clientY - canvas.offsetTop);    },    onMouseDown: function(e, ctx, canvas) {        canvas.isDrawing = true;        this.fillSquare(ctx, e.clientX - canvas.offsetLeft, e.clientY - canvas.offsetTop);    },    onMouseUp: function(e) {        canvas.isDrawing = false;    },    fillSquare: function(ctx, x, y) {        var xPixel = Math.floor(x / this.PIXEL_WIDTH);        var yPixel = Math.floor(y / this.PIXEL_WIDTH);        // 儲存手寫輸入資料        this.data[((xPixel - 1)  * this.TRANSLATED_WIDTH + yPixel) - 1] = 1;        ctx.fillStyle = '#ffffff';        ctx.fillRect(xPixel * this.PIXEL_WIDTH, yPixel * this.PIXEL_WIDTH, this.PIXEL_WIDTH, this.PIXEL_WIDTH);    },    train: function() {        var digitVal = document.getElementById("digit").value;        if (!digitVal || this.data.indexOf(1) < 0) {            alert("Please type and draw a digit value in order to train the network");            return;        }        // 將資料加入客戶端訓練資料集        this.trainArray.push({"y0": this.data, "label": parseInt(digitVal)});        this.trainingRequestCount++;        // 將客服端訓練資料集傳送給伺服器端        if (this.trainingRequestCount == this.BATCH_SIZE) {            alert("Sending training data to server...");            var json = {                trainArray: this.trainArray,                train: true            };            this.sendData(json);            this.trainingRequestCount = 0;            this.trainArray = [];        }    },    // 傳送預測請求    test: function() {        if (this.data.indexOf(1) < 0) {            alert("Please draw a digit in order to test the network");            return;        }        var json = {            image: this.data,            predict: true        };        this.sendData(json);    },    // 處理伺服器響應    receiveResponse: function(xmlHttp) {        if (xmlHttp.status != 200) {            alert("Server returned status " + xmlHttp.status);            return;        }        var responseJSON = JSON.parse(xmlHttp.responseText);        if (xmlHttp.responseText && responseJSON.type == "test") {            alert("The neural network predicts you wrote a \'" + responseJSON.result + '\'');        }    },    onError: function(e) {        alert("Error occurred while connecting to server: " + e.target.statusText);    },    sendData: function(json) {        var xmlHttp = new XMLHttpRequest();        xmlHttp.open('POST', this.HOST + ":" + this.PORT, false);        xmlHttp.onload = function() { this.receiveResponse(xmlHttp); }.bind(this);        xmlHttp.onerror = function() { this.onError(xmlHttp) }.bind(this);        var msg = JSON.stringify(json);        xmlHttp.setRequestHeader('Content-length', msg.length);        xmlHttp.setRequestHeader("Connection", "close");        xmlHttp.send(msg);    }}

效果如下圖:

此處輸入圖片的描述

實現伺服器端

伺服器端由Python標準庫BaseHTTPServer實現,我們接收從客戶端發來的訓練或是預測請求,使用POST報文,由於邏輯簡單,方便起見,兩種請求就發給同一個URL了,在實際生產中還是分開比較好。

完整程式碼如下:

# -*- coding: UTF-8 -*-import BaseHTTPServerimport jsonfrom ocr import OCRNeuralNetworkimport numpy as npimport random#伺服器端配置HOST_NAME = 'localhost'PORT_NUMBER = 9000#這個值是通過執行神經網路設計指令碼得到的最優值HIDDEN_NODE_COUNT = 15# 載入資料集data_matrix = np.loadtxt(open('data.csv', 'rb'), delimiter = ',')data_labels = np.loadtxt(open('dataLabels.csv', 'rb'))# 轉換成list型別data_matrix = data_matrix.tolist()data_labels = data_labels.tolist()# 資料集一共5000個數據,train_indice儲存用來訓練的資料的序號train_indice = range(5000)# 打亂訓練順序random.shuffle(train_indice)nn = OCRNeuralNetwork(HIDDEN_NODE_COUNT, data_matrix, data_labels, train_indice);classJSONHandler(BaseHTTPServer.BaseHTTPRequestHandler):    """處理接收到的POST請求"""    defdo_POST(self):        response_code = 200        response = ""        var_len = int(self.headers.get('Content-Length'))        content = self.rfile.read(var_len);        payload = json.loads(content);        # 如果是訓練請求,訓練然後儲存訓練完的神經網路        if payload.get('train'):            nn.train(payload['trainArray'])            nn.save()        # 如果是預測請求,返回預測值        elif payload.get('predict'):            try:                print nn.predict(data_matrix[0])                response = {"type":"test", "result":str(nn.predict(payload['image']))}            except:                response_code = 500        else:            response_code = 400        self.send_response(response_code)        self.send_header("Content-type", "application/json")        self.send_header("Access-Control-Allow-Origin", "*")        self.end_headers()        if response:            self.wfile.write(json.dumps(response))        returnif __name__ == '__main__':    server_class = BaseHTTPServer.HTTPServer;    httpd = server_class((HOST_NAME, PORT_NUMBER), JSONHandler)    try:        #啟動伺服器        httpd.serve_forever()    except KeyboardInterrupt:        pass    else:        print "Unexpected server exception occurred."    finally:        httpd.server_close()

實現神經網路

如之前所說,我們使用反向傳播演算法(Backpropagation)來訓練神經網路,演算法背後的原理推導推薦閱讀這篇博文:反向傳播神經網路極簡入門

演算法主要分為三個步驟:

第一步:初始化神經網路

一般將所有權值與偏置量置為(-1,1)範圍內的隨機數,在我們這個例子中,使用(-0.06,0.06)這個範圍,輸入層到隱藏層的權值儲存在矩陣theta1中,偏置量存在input_layer_bias中,隱藏層到輸出層則分別存在theta2hidden_layer_bias中。

建立隨機矩陣的程式碼如下,注意輸出的矩陣是以size_out為行,size_in為列。可能你會想為什麼不是size_in在左邊。你可以這麼想,一般都是待處理的輸入放在右邊,處理操作(矩陣)放在左邊。

def_rand_initialize_weights(self, size_in, size_out):    return [((x * 0.12) - 0.06) for x in np.random.rand(size_out, size_in)]

初始化權值矩陣與偏置向量:

self.theta1 = self._rand_initialize_weights(400, num_hidden_nodes)self.theta2 = self._rand_initialize_weights(num_hidden_nodes, 10)self.input_layer_bias = self._rand_initialize_weights(1,                                                       num_hidden_nodes)self.hidden_layer_bias = self._rand_initialize_weights(1, 10)

這裡說明一下會用到的每一個矩陣/向量及其形狀:

變數名描述形狀
y0輸入層1 * 400
theta1輸入-隱藏層權值矩陣隱藏層節點數 * 400
input_layer_bias輸入-隱藏層偏置向量隱藏層節點數 * 1
y1隱藏層隱藏層節點數 * 1
theta2隱藏-輸出層權值矩陣10 * 隱藏層節點數
hidden_layer_bias隱藏-輸出層偏置向量10 * 1
y2輸出層10 * 1

第二步:前向傳播

前向傳播就是輸入資料通過一層一層計算到達輸出層得到輸出結果,輸出層會有10個節點分別代表0~9,哪一個節點的輸出值最大就作為我們的預測結果。還記得前面說的激發函式嗎?一般用sigmoid函式作為激發函式。

# sigmoid激發函式def_sigmoid_scalar(