10 行程式碼,實現手寫數字識別
識別手寫的阿拉伯數字,對於人類來說十分簡單,但是對於程式來說還是有些複雜的。
不過隨著機器學習技術的普及,使用10幾行程式碼,實現一個能夠識別手寫數字的程式,並不是一件難事。這是因為有太多的機器學習模型可以拿來直接用,比如tensorflow、caffe,在python下都有現成的安裝包,寫一個識別數字的程式,10幾行程式碼足夠了。
然而我想做的,是不借助任何第三方的庫,從零開始,完全自己實現一個這樣的程式。之所以這麼做,是因為自己動手實現,才能深入瞭解機器學習的原理。
1 模型實現
1.1 原理
熟悉神經網路迴歸演算法的,可以略過這一節了。
學習了一些基本概念,決定使用迴歸演算法。首先下載了著名的MNIST資料集,這個資料集有60000個訓練樣本,和10000個測試樣本。每個數字圖片都是28*28的灰度圖片,所以輸入可以認為是一個28*28的矩陣,也可以認為是一個28*28=784個畫素值。
這裡定義一個模型用於判斷一個圖片數字,每個模型包括每個輸入的權重,加一個截距,最後再做個歸一。模型的表示式:
Out5= sigmoid(X0*W0+ X1*W1+……X783*W783+bias)
X0到X783是784個輸入,W0到W783是784個權重,bias是一個常量。sigmoid函式可以將較大範圍的數擠壓到(0,1)區間內,也就是歸一。
例如我們用這一組權重和bias來判斷數字5,期望當圖片是5時輸出是1,當不是5時輸出是0。然後訓練的過程就是根據每個樣本的輸入,計算Out5的值和正確值(0或1)的差距,然後根據這個差距,調整權重和bias。轉換一下公式,就是在努力使得(Out5-正確值)接近於0,即所謂損失最小。
同理,10個數字就要有10套模型,每個判斷不同的數字。訓練好以後,一個圖片來了,用這10套模型進行計算,哪個模型計算的結果更接近於1,就認為這個圖片是哪個數字。
1.2 訓練
按照上面的思路,使用集算器的SPL(結構化處理語言)來編碼實現:
A |
B |
C |
|
1 |
=file("train-imgs.btx")[email protected]() |
||
2 |
>x=[],wei=[],bia=[],v=0.0625,cnt=0 |
||
3 |
for 10 |
>wei.insert(0,[to(28*28).(0)]), bia.insert(0,0.01) |
|
4 |
for 50000 |
>label=A1.fetch(1)(1) |
|
5 |
>y=to(10).(0), y(label+1)=1,x=[] |
||
6 |
>x.insert(0,A1.fetch(28*28)) |
>x=x.(~/255) |
|
7 |
=wei.(~**x).(~.sum()) ++ bia |
||
8 |
=B7.(1/(1+exp(-~))) |
||
9 |
=(B8--y)**(B8.(1-~))**B8 |
||
10 |
for 10 |
>wei(B10)=wei(B10)--x.(~*v*B9(B10)), bia(B10)=bia(B10) - v*B9(B10) |
|
11 |
>file("MNIST模型.btx")[email protected](wei), file("MNIST模型.btx")[email protected](bia) |
不用再找了,訓練模型的所有程式碼都在這裡了,沒有用到任何第三方庫,下面解析一下:
A1,用遊標匯入MNIST訓練樣本,這個是我轉換過的格式,可以被集算器直接訪問;
A2,定義變數:輸入x,權重wei,訓練速度v,等;
A3,B3,初始化10組模型(每組是784個權重+1個bias);
A4,迴圈取5萬個樣本進行訓練,10模型同時訓練;
B4,取出來label,即這個圖片是幾;
B5,計算正確的10個輸出,儲存到變數y;
B6,取出來這個圖片的28*28個畫素點作為輸入,C6把每個輸入除以255,這是為了歸一化;
B7,計算X0*W0+ X1*W1+……X783*W783+bias
B8,計算sigmoid(B7)
B9,計算B8的偏導,或者叫梯度;
B10,C10,根據B9的值,迴圈調整10個模型的引數;
A11,訓練完畢,把模型儲存到檔案。
1.3 測試
測試一下這個模型的成功率吧,用 SPL 寫了一個測試程式:
A |
B |
C |
|
1 |
=file("MNIST模型.btx")[email protected]() |
=[0,1,2,3,4,5,6,7,8,9] |
|
2 |
>wei=A1.fetch(10),bia=A1.fetch(10) |
||
3 |
>cnt=0 |
||
4 |
=file("test-imgs.btx")[email protected]() |
||
5 |
for 10000 |
>label=A4.fetch(1)(1) |
|
6 |
>x=[] |
||
7 |
>x.insert(0,A4.fetch(28*28)) |
>x=x.(~/255) |
|
8 |
=wei.(~**x).(~.sum()) ++ bia |
||
9 |
=B8.(round(1/(1+exp(-~)), 2)) |
||
10 |
=B9.pmax() |
||
11 |
if label==B1(B10) |
>cnt=cnt+1 |
|
12 |
=A1.close() |
||
13 |
=output(cnt/100) |
執行測試,正確率達到了91.1%,我對這個結果是很滿意的,畢竟這只是一個單層模型,我用TensorFlow的單層模型得到的正確率也是91%多一點。下面解析一下程式碼:
A1,匯入模型檔案;
A2,把模型提取到變數裡;
A3,計數器初始化(用於計算成功率);
A4,匯入MNIST測試樣本,這個檔案格式是我轉換過的;
A5,迴圈取1萬個樣本進行測試;
B5,取出來label;
B6,清空輸入;
B7,取出來這個圖片的28*28個畫素點作為輸入,每個輸入除以255,這是為了歸一化;
B8,計算X0*W0+ X1*W1+……X783*W783+bias
B9,計算sigmoid(B7)
B10,得到最大值,即最可能的那個數字;
B11,判斷正確測計數器加一;
A12,A13,測試結束,關閉檔案,輸出正確率。
1.4 優化
這裡要說的優化並不是繼續提高正確率,而是提升訓練的速度。想提高正確率的同學可以嘗試一下這幾個手段:
1. 加一個卷積層;
2. 學習速度不要用固定值,而是隨著訓練次數遞減;
3. 權重的初始值不要使用全零,使用正態分佈;
我認為單純追求正確率的意義不大,因為MNIST資料集有些圖片本身就有問題,即使人工也不一定能知道寫的是數字幾。我用集算器顯示了幾張出錯的圖片,都是書寫十分不規範的,下面這個圖片很難看出來是2。
下面說重點,要提高訓練速度,可以使用並行或叢集。使用SPL語言實現並行很簡單,只要使用fork關鍵字,把上面的程式碼稍加處理就可以了。
A |
B |
C |
D |
|
1 |
=file("train-imgs.btx")[email protected]() |
|||
2 |
>x=[],wei=[],bia=[],v=0.0625,cnt=0 |
>mode=to(0,9) |
||
3 |
>wei=to(28*28).(0) |
|||
4 |
fork mode |
=A1.cursor() |
||
5 |
for 50000 |
>label=B4.fetch(1)(1) |
>y=1,x=[] |
|
6 |
if label!=A4 |
>y=0 |
||
7 |
>x.insert(0,B4.fetch(28*28)) |
>x=x.(~/255) |
||
8 |
=(wei**x).sum() + bia |
|||
9 |
=1/(1+exp(-C8)) |
|||
10 |
=(C9-y)*((1-C9))*C9 |
|||
11 |
>wei=wei--x.(~*v*C10), bia=bia- v*C10 |
|||
12 |
return wei,bia |
|||
13 |
=movefile(file("MNIST模型.btx")) |
|||
14 |
for 10 |
>file("MNIST模型.btx")[email protected]([A4(A15)(1)]) |
||
15 |
for 10 |
>file("MNIST模型.btx")[email protected]([A4(A16)(2)]) |
使用了並行之後,訓練的時間減少差不多一半,而程式碼並沒有做太多修改。
2 為什麼是 SPL 語言?
使用SPL語言在初期可能會有點不適應,用得多了會覺得越來越方便:
1. 支援集合運算,比如例子裡用到的784個輸入和784個權重的乘法,直接寫一個**就可以了,如果使用Java或者C,還要自己實現。
2. 資料的輸入輸出很方便,可以方便地對檔案讀寫。
3. 除錯太方便了,所有變數都直觀可見,這一點比python要好用。
4. 可以單步計算,有了改動不用從頭重來,Java和C做不到這一點,python雖然可以但也不方便,集算器只要點中相應格執行就可以了。
5. 實現並行和叢集很方便,不需要太多的開發工作量。
6. 支援呼叫和被呼叫。集算器可以呼叫第三方java庫,Java也可以呼叫集算器的程式碼,例如上面的程式碼就可以被Java呼叫,實現一個自動填驗證碼的功能。
這樣的程式語言,用在數學計算上,實在是最合適不過了。
作者:liwei
連結:http://c.raqsoft.com.cn/article/1540374496048
來源:乾學院
著作權歸作者所有。商業轉載請聯絡作者獲得授權,非商業轉載請註明出處。