1. 程式人生 > >matlab+BP神經網路實現手寫體數字識別

matlab+BP神經網路實現手寫體數字識別

個人部落格文章連結:http://www.huqj.top/article?id=168

接著上一篇所說的 BP神經網路,現在用它來實現一個手寫體數字的識別程式,訓練素材來自吳恩達機器學習課程,我把打包好上傳到了網盤上:

1

2

連結:https://pan.baidu.com/s/1_u8zXzkQcY0iS3cgq2k0xg 

提取碼:4opy

    訓練資料一共有5000條,10個數字(0~9,為了和matlab適配,0在這裡統一用10表示),每個數字各500個手寫體圖片,畫素統一處理為20*20,其中pics中是5000張圖片,   data是一個.mat檔案,可以直接載入到matlab中,包含兩個變數X(5000x400 double矩陣)和y(5000x1 int矩陣)。

image.png

image.png

image.png

 

    可以看到,訓練資料的輸入是400個畫素點的灰度值,雖然圖片是20x20的,但是為了處理方便將其轉換成1x400的輸入,可以用matlab中的reshape函式進行轉換。而對於輸出而言,這可以看作一個多元分類問題,一共有10種分類,所以輸出可以轉換成一個10維向量。定義好輸入輸出格式之後,再考慮下神經網路的架構,平衡效能和效率,最終選擇的架構是一個25元隱含層的BP網路。另外,為了衡量最終的模型效果,我們需要從5000個數據中抽取一部分作為測試集,這裡我每個數字選了10條資料作為測試資料集,不過理論上訓練集和測試集的比例可以達到 7:3

    利用之前編寫好的BP網路訓練函式和一些附加函式(sigmoid,預測函式等),最終的手寫體識別訓練程式如下:

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

% 實現一個手寫體數字識別的神經網路訓練程式

clc;

clear;

 

load('machine-learning-ex3\\ex3\\ex3data1.mat');

 

% 展示100張圖片,也是測試集

testData = [X(1:10, :); X(501:510, :); X(1001:1010, :); X(1501:1510, :); X(2001:2010, :); 

    X(2501:2510, :); X(3001:3010, :); X(3501:3510, :); X(4001:4010, :); X(4501:4510, :)];

testResult = [y(1:10, :); y(501:510, :); y(1001:1010, :); y(1501:1510, :); y(2001:2010, :); 

    y(2501:2510, :); y(3001:3010, :); y(3501:3510, :); y(4001:4010, :); y(4501:4510, :)];

% displayData(testData, 20);

 

% 準備訓練資料

num = 400;  % 每個數字訓練資料集的大小,最大490

trainingData = [X(11:10 + num, :); X(511:510 + num, :); X(1011:1010 + num, :); X(1511:1510 + num, :); X(2011:2010 + num, :); 

    X(2511:2510 + num, :); X(3011:3010 + num, :); X(3511:3510 + num, :); X(4011:4010 + num, :); X(4511:4510 + num, :)];

trainingY = [y(11:10 + num, :); y(511:510 + num, :); y(1011:1010 + num, :); y(1511:1510 + num, :); y(2011:2010 + num, :); 

    y(2511:2510 + num, :); y(3011:3010 + num, :); y(3511:3510 + num, :); y(4011:4010 + num, :); y(4511:4510 + num, :)];

trainingResult = zeros(length(trainingY), 10);

for i = 1 : size(trainingResult, 1)

    trainingResult(i, trainingY(i)) = 1;  % 相應的數字位下標為1,注意10為1表示數字是0

end

 

% 模型引數

size_ = [400, 25, 10];   % 輸入為400個畫素點的灰度值,輸出為一個10維向量

alpha = 0.8;

lambda = 0.5;

threshold = 0.01;

load('weight.mat', 'W');

% W = [];

while (true)

    maxIter = input('輸入想要迭代的最大次數,輸入-1結束:\n');

    if (maxIter == -1)

        break;

    end

    [W, delta, IterNum] = BPNN(size_, alpha, lambda, threshold, maxIter, trainingData, trainingResult, W);

    fprintf('delta=%f, iteration num=%d',delta, IterNum);

end

 

% 儲存權重

save('weight.mat', W);

 

% 測試準確率

[res, ~] = BPNNPredict(size_, W, testData, zeros(length(testResult), 10));

precious = 0;

for i = 1 : length(testResult)

   tmp = max(res(i, :));

   res(i, :) = (res(i, :) == tmp);

   if res(i, round(testResult(i))) == 1

       precious = precious + 1;

   end

end

fprintf('pricision: %f\n', precious / length(testResult));

    因為一開始不知道要迭代多少次,所以設定成了一個迴圈的結構,可以根據訓練誤差決定繼續訓練或者結束訓練,然後將模型權重儲存下來,下次可以接著訓練。

    如果想要在matlab中畫出圖片,可以將這一行的註釋去掉:

1

% displayData(testData, 20);

    然後繪出所有測試集的圖片如下:

image.png

    執行程式反覆迭代上萬次之後,在測試集上的準確率穩定在92%左右,這可能也是受模型和資料集的限制。而且這個模型只是用於黑底白字的圖片,用我自己的手寫數字測試效果並不太好(可能與我的圖片處理有關),最高只能達到 7/10 的準確率,後續會持續考慮改進模型。

    完整程式碼下載地址: https://download.csdn.net/download/qq_32216775/10897369