1. 程式人生 > >基於BP神經網路的數字識別基礎系統(四)

基於BP神經網路的數字識別基礎系統(四)

基於BP神經網路的數字識別基礎系統(四)

接上篇

上一篇的連結:http://blog.csdn.net/z_x_1996/article/details/68490009

3.系統設計

上一篇筆者已經討論完了BP神經網路需要用到的知識點,接下來就開始設計符合我們標題的系統了。

首先我們要確定訓練集以及測試集:下載連結:http://download.csdn.net/detail/z_x_1996/9799552


我們來分析訓練集,首先訓練的圖片格式為bmp點陣圖格式,位深度為8,解析度為32*64,訓練集分為0~9十個資料夾,每個資料夾裡面有4張不同字型的相同數字(數字同文件夾名稱),同時訓練集裡有一個target.txt檔案,裡面檔案代表每一張圖片的目標輸出,一行就是一張圖的目標輸出,我們很容易看出輸出有10個單元,每個數字對應一組輸出。這裡並沒有採用二進位制編碼而是採用一對一編碼,這樣的好處在於可以很容易獲得置信度,但是壞處也是顯而易見的,那就是當樣本型別很多時網路的輸出會急劇增加。

我們再來看輸入層,為了精簡輸入資訊,我們將圖片壓縮,橫豎均只取1/4的畫素,均勻分佈。這樣輸入單元有32*64/16=128個輸入單元。

隱藏層有多種選擇,首先確定隱藏層數,考慮到該資料組分類比較簡單,故選擇一層隱藏層,這層的單元數有多種選擇,不同的選擇會有不同的影響,這個影響我們後面再談(如果忘了請記得提醒筆者),這裡我們選擇為4個。

至此我們便確定了網路結構,三層:

  • 輸入層:128單元
  • 隱藏層:8單元
  • 輸出層:10單元

這樣我們也可以把權重向量的size確定了:

  • weightHK[][]:10x(8+1)
  • weightIH[][]:8x(128+1)

(這裡+1的原因是要加上一個常數偏置項)

首先筆者先給出系統工程的結構圖:

3.1 神經網路包

我們先構建神經網路元素包 com.zhangxiao.element。

首先自然來到我們SNeuron.java檔案,該檔案為一個神經元。

package com.zhangxiao.element;

public class SNeuron {
    private double[] weight;
    private double[] input;
    private int length;

    public SNeuron(double[] input,double[] weight){
        this.input = input;
        this.length = input.length;
        this.weight = weight;
    }

    //獲得Sigmoid輸出結果。
    public double getResult(){
        double sum = weight[0];
        for(int i=0;i<length;i++){
            sum += input[i]*weight[i+1];
        }
        return 1/(Math.exp(-sum)+1);
    }

}

沒有什麼好說的,然後是構建一層 Layer.java,該檔案為一層的類。

package com.zhangxiao.element;

public class Layer {

    private SNeuron[] cells;
    private int number;
    private double[] input;
    private double[] output;
    private double[][] weight;
    //初始化神經層
    public Layer(int number,double[] input,double[][] weight){
        this.number = number;
        this.input = input;
        this.weight = weight;
        output = new double[number];
        cells = new SNeuron[number];
        for(int i=0;i<number;i++){
            cells[i] = new SNeuron(this.input,this.weight[i]);
        }
    }
    //獲得神經層輸出結果陣列
    public void goForward() {
        for(int i=0;i<number;i++){
            output[i] = cells[i].getResult();
        }
    }

    public double[] getOutput() {
        return output;
    }
}

然後是構建一個神經系統(目前筆者寫的程式碼只支援3層,即一個隱藏層),NervousSystem1H.java

package com.zhangxiao.element;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;

public class NervousSystem1H {
    private double[][] trainData;
    private double[] input;
    private double[] output;
    private double[] connection;
    public double[] getInputLayer() {
        return input;
    }

    private double[][] target;
    private Layer[] layers;
    private int[] structure;
    private double efficiency;
    private double[] deltaK;
    private double[] deltaH;

    private double[][] weightIH;
    private double[][] weightHK;        

    //初始化神經系統
    public NervousSystem1H(double efficiency,int[] structure,double[][] trainData,double[][] target) throws IOException{

        if(trainData[0].length!=structure[0]){
            System.out.println("訓練資料長度與輸入層長度不一致!");
            return;
        }

        this.trainData = trainData;
        this.target = target;
        this.efficiency = efficiency;
        this.structure = structure;

        //初始化陣列
        this.input = new double[structure[0]];
        deltaK = new double[structure[2]];
        deltaH = new double[structure[1]];      
        for(int k=0;k<deltaK.length;k++){
            deltaK[k] = 0;
        }
        for(int h=0;h<deltaH.length;h++){
            deltaH[h] = 0;
        }
        weightIH = new double[structure[1]][structure[0]+1];
        weightHK = new double[structure[2]][structure[1]+1];
        for(int h=0;h<structure[1];h++){
            for(int i=0;i<structure[0]+1;i++){
                while(Math.abs((weightIH[h][i] = Math.random()/10-0.05))==0){}
            }
        }
        for(int k=0;k<structure[2];k++){
            for(int h=0;h<structure[1]+1;h++){
                while(Math.abs(weightHK[k][h] = Math.random()/10-0.05)==0){}
            }
        }

        //連線各層
        layers= new Layer[2];
        layers[0] = new Layer(structure[1],this.input,weightIH);
        connection = layers[0].getOutput();
        layers[1] = new Layer(structure[2],connection,weightHK);
        this.output = layers[1].getOutput();

    }

    //訓練神經網路
    public void train() throws IOException{
        double error = 0;
        int process = 0;
        while((error = getError())>0.0001){
            System.out.println(process++ +":"+error);
            for(int d=0;d<trainData.length;d++){
                //正向傳播輸出
                goForward(trainData[d]);

                double[] outputK = layers[1].getOutput();
                double[] outputH = layers[0].getOutput();

                for(int k=0;k<deltaK.length;k++){
                    deltaK[k] = outputK[k]*(1-outputK[k])*(target[d][k]-outputK[k]);
                }
                for(int h=0;h<deltaH.length;h++){
                    deltaH[h] = 0;
                    for(int k=0;k<deltaK.length;k++){
                        deltaH[h] += outputH[h]*(1-outputH[h])*deltaK[k]*weightHK[k][h+1];
                    }
                }
                //更新權值

                for(int k=0;k<weightHK.length;k++){
                    weightHK[k][0] += efficiency*deltaK[k];
                    for(int h=1;h<weightHK[0].length;h++){
                        weightHK[k][h] += efficiency*deltaK[k]*outputH[h-1];
                    }
                }

                for(int h=0;h<weightIH.length;h++){
                    weightIH[h][0] += efficiency*deltaH[h];
                    for(int i=1;i<weightIH[0].length;i++){
                        weightIH[h][i] += efficiency*deltaH[h]*trainData[d][i-1];
                    }
                }
            }
        }
        System.out.println("最終誤差為:"+getError());
    }

    //獲取輸出結果陣列
    public void goForward(double[] input){
        setInput(input);
        for(int i = 0;i<structure.length-1;i++){
            layers[i].goForward();
        }
    }

    //獲取誤差
    public double getError(){
        double error = 0;
        for(int d=0;d<trainData.length;d++){
            goForward(trainData[d]);
            for(int i=0;i<target[0].length;i++){
                error += 0.5*(target[d][i]-output[i])*(target[d][i]-output[i]);
            }
        }
        return error/trainData.length/10;
    }

    //將訓練好的權重儲存到txt檔案中方便檢視以及二次呼叫
    public boolean saveWeight(File file) throws IOException{
        boolean flag = false;
        BufferedWriter bw = new BufferedWriter(new FileWriter(file));
        //寫入weightIH
        for(int h=0;h<weightIH.length;h++){
            for(int i=0;i<weightIH[0].length;i++){
                bw.append(Double.toString(weightIH[h][i])+" ");
            }
            bw.append("\r\n");
            bw.flush();
        }
        //寫入weightHK
        for(int k=0;k<weightHK.length;k++){
            for(int h=0;h<weightHK[0].length;h++){
                bw.append(Double.toString(weightHK[k][h])+" ");
            }
            bw.append("\r\n");
            bw.flush();
        }
        bw.close();
        return flag;
    }

    //呼叫訓練好的網路
    public boolean loadWeight(File file) throws IOException{
        boolean flag = false;
        BufferedReader br = new BufferedReader(new FileReader(file));
        //寫入weightIH
        String line;
        String[] strs;
        for(int h=0;h<weightIH.length;h++){
            line=br.readLine();
            strs = line.split(" ");
            for(int i=0;i<weightIH[0].length;i++){
                weightIH[h][i] = Double.parseDouble(strs[i]);
            }
        }
        //寫入weightHK
        for(int k=0;k<weightHK.length;k++){
            line=br.readLine();
            strs = line.split(" ");
            for(int h=0;h<weightHK[0].length;h++){
                weightHK[k][h] = Double.parseDouble(strs[h]);
            }
        }
        br.close();
        return flag;
    }

    //網路每個輸出單元的輸出
    public double[] predict_all(double[] input){
        goForward(input);
        return output;
    }

    //輸出預測數字
    public int preidict_result(double[] input){
        int result = -1;
        double max = -1;
        goForward(input);
        for(int i=0;i<output.length;i++){
            if(output[i]>max){
                max = output[i];
                result = 9-i;
            }
        }
        return result;
    }

    private void setInput(double[] input) {
        for(int i=0;i<this.input.length;i++){
            this.input[i] = input[i];
        }
    }

    public double[][] getWeightIH() {
        return weightIH;
    }

    public double[][] getWeightHK() {
        return weightHK;
    }

}

這裡需要說明的是主要的計算量為goForward函式,這個是正向計算的函式。如果看懂了前面的原理這個檔案其實也沒什麼好講的,無非是把輸出細節化,訓練方法和前面所說一樣。同樣增加了getError函式來獲取誤差,因為筆者把Error來作為訓練終止的要求。但是其實使用這個作為終止條件摒棄了增量梯度下降演算法中不需要一次性載入所有資料的優點。計算Error必須使用所有的資料。

這樣一個網路的架構就已經搭建好了,使用時我們只需要呼叫NervousSystem1H類中的方法就可以了。

3.2 主程式包

下面就是要建立針對專案的主程式包com.zhangxiao.window了。

Window.java中主要應該包括如下方法:

  • 獲取訓練資料
  • 圖片資料轉化為陣列
  • 獲取訓練標籤
  • 構建神經網路

    package com.zhangxiao.window;
    
    import java.awt.image.BufferedImage;
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.IOException;
    import javax.imageio.ImageIO;
    
    import com.zhangxiao.element.NervousSystem1H;
    
    public class Window {
    
        public static void main(String[] args) throws IOException {
            String path = "這裡填自己的路徑";
            //獲取訓練素材
            double[] testData = new double[128];//這裡記得獲取測試資料!!!!!!!!!!!!!!!!!!
            double[][] target = getTarget(path, 10, 40);
            double[][] trainData = getTrainData(path, 40);
            int[] structure = new int[]{128,8,10};
            //構建神經網路
            NervousSystem1H s = new NervousSystem1H(0.01,structure,trainData,target);
            //訓練神經網路
            System.out.println("訓練中...");
            s.train();
            System.out.println("訓練完畢!");    
            //儲存weight資料
            s.saveWeight(new File("data/weight/weight.txt"));
    
            //載入儲存的weight資料
            /*System.out.println("載入中...");
            s.loadWeight(new File("data/weight/weight.txt"));
            System.out.println("載入完成!");*/
    
            double[] result = s.predict_all(testData);
            for(int i=0;i<result.length;i++){
                System.out.print(result[i]+" ");
            }
        }
    
        //獲取訓練樣本
        public static double[][] getTrainData(String direction,int number) throws IOException{
            double[][] trainData = new double[number][128];
            for(int d=0;d<number/4;d++){
                for(int i=0;i<4;i++){
                    trainData[4*d+i] = image2Array(direction+"/"+d+"/"+d+""+i+".bmp");
                }
            }
            return trainData;
        }
    
        //將圖片轉化為陣列
        public static double[] image2Array(String str) throws IOException{
            double[] data = new double[16*8];
            BufferedImage image = ImageIO.read(new File(str));
            for(int i = 0;i<8;i++){
                for(int j = 0;j<16;j++){
                    int color = image.getRGB(4*i, 4*j);
                    int b = color&0xff;
                    int g = (color>>8)&0xff;
                    int r = (color>>8)&0xff;
                    data[8*j+i]=((int)(r*0.3+g*0.59+b*0.11))/255;
                }
            }
            return data;
        }
    
        //獲取目標結果陣列
        @SuppressWarnings("resource")
        public static double[][] getTarget(String str,int length,int number) throws IOException{
            BufferedReader br = new BufferedReader(new FileReader(str));
            double[][] data = new double[number][length];
            String line;
            String[] strs;
            int d = 0;
            while((line=br.readLine())!=null){
                strs = line.split(" ");
                for(int i=0;i<length;i++){
                    data[d][i] = Double.parseDouble(strs[i]);
                }
                d++;
            }
            if(d!=number){
                System.out.println("資料組數不匹配!");
                return null;
            }
            br.close();
            return data;
        }
    
    }
    

4.後記

到這裡這個坑基本上算是填完了,當然筆者還是需要說明的是由於程式碼寫的比較匆忙,很多冗餘、不夠優化以及結構問題比比皆是。希望大家能夠諒解,如果有很好的建議方便留言。到目前為止這個系列筆者前前後後花費了很多的精力以及時間,終於完成了這個兩萬多字的系列,可以說從中也學到了很多東西,很多以前並不是很清楚的東西也理清楚了。另外這裡給出大家一個優化的方向:

  • 加入衝量項,避開區域性最小值。
  • 改變隱藏層的單元數。

如果覺得看完還是有些疑惑的建議自己再復建一下演算法,或者你可以試試將隱藏層數變為2層,再來思考整個系統,相信你會受益匪淺!