1. 程式人生 > >機器學習02之BP神經網路圖解及JAVA實現

機器學習02之BP神經網路圖解及JAVA實現




package com.fei.bp02;
public class Bp {

    private double[] hide1_x;//// 輸入層即第一層隱含層的輸入;hide1_x[資料的特徵數目+1], hide1_x[0]為1
    private double[][] hide1_w;// 隱含層權值,hide1_w[本層的節點的數目][資料的特徵數目+1];hide_w[0][0]為偏置量
    private double[] hide1_errors;// 隱含層的誤差,hide1_errors[節點個數]

    private double[] out_x;// 輸出層的輸入值即第二次層隱含層的輸出 out_x[上一層的節點數目+1], out_x[0]為1
    private double[][] out_w;// 輸出層的權值 hide1_w[節點的數目][上一層的節點數目+1]//
                                // out_w[0][0]為偏置量
    private double[] out_errors;// 輸出層的誤差 hide1_errors[節點個數]

    private double[] target;// 目標值,target[輸出層的節點個數]

    private double rate;// 學習速率

    public Bp(int input_node, int hide1_node, int out_node, double rate) {
        super();

        // 輸入層即第一層隱含層的輸入
        hide1_x = new double[input_node + 1];

        // 第一層隱含層
        hide1_w = new double[hide1_node][input_node + 1];
        hide1_errors = new double[hide1_node];

        // 輸出層
        out_x = new double[hide1_node + 1];
        out_w = new double[out_node][hide1_node + 1];
        out_errors = new double[out_node];

        target = new double[out_node];

        // 學習速率
        this.rate = rate;
        init_weight();// 1.初始化網路的權值
    }

    /**
     * 初始化權值
     */
    public void init_weight() {

        set_weight(hide1_w);
        set_weight(out_w);
    }

    /**
     * 初始化權值
     * 
     * @param w
     */
    private void set_weight(double[][] w) {
        for (int i = 0, len = w.length; i != len; i++)
            for (int j = 0, len2 = w[i].length; j != len2; j++) {
                w[i][j] = 0;
            }
    }

    /**
     * 獲取原始資料
     * 
     * @param Data
     *            原始資料矩陣
     */
    private void setHide1_x(double[] Data) {
        if (Data.length != hide1_x.length - 1) {
            throw new IllegalArgumentException("資料大小與輸出層節點不匹配");
        }
        System.arraycopy(Data, 0, hide1_x, 1, Data.length);
        hide1_x[0] = 1.0;
    }

    /**
     * @param target
     *            the target to set
     */
    private void setTarget(double[] target) {
        this.target = target;
    }

    /**
     * 2.訓練資料集
     * 
     * @param TrainData
     *            訓練資料
     * @param target
     *            目標
     */
    public void train(double[] TrainData, double[] target) {
        // 2.1匯入訓練資料集和目標值
        setHide1_x(TrainData);
        setTarget(target);

        // 2.2:向前傳播得到輸出值;
        double[] output = new double[out_w.length + 1];
        forword(hide1_x, output);

        // 2.3、方向傳播:
        backpropagation(output);

    }

    /**
     * 反向傳播過程
     * 
     * @param output
     *            預測結果
     */
    public void backpropagation(double[] output) {

        // 2.3.1、獲取輸出層的誤差;
        get_out_error(output, target, out_errors);
        // 2.3.2、獲取隱含層的誤差;
        get_hide_error(out_errors, out_w, out_x, hide1_errors);
        //// 2.3.3、更新隱含層的權值;
        update_weight(hide1_errors, hide1_w, hide1_x);
        // * 2.3.4、更新輸出層的權值;
        update_weight(out_errors, out_w, out_x);
    }

    /**
     * 預測
     * 
     * @param data
     *            預測資料
     * @param output
     *            輸出值
     */
    public void predict(double[] data, double[] output) {

        double[] out_y = new double[out_w.length + 1];
        setHide1_x(data);
        forword(hide1_x, out_y);
        System.arraycopy(out_y, 1, output, 0, output.length);

    }

    
    public void update_weight(double[] err, double[][] w, double[] x) {

        double newweight = 0.0;
        for (int i = 0; i < w.length; i++) {
            for (int j = 0; j < w[i].length; j++) {
                newweight = rate * err[i] * x[j];
                w[i][j] = w[i][j] + newweight;
            }

        }
    }

    /**
     * 獲取輸出層的誤差
     * 
     * @param output
     *            預測輸出值
     * @param target
     *            目標值
     * @param out_error
     *            輸出層的誤差
     */
    public void get_out_error(double[] output, double[] target, double[] out_error) {
        for (int i = 0; i < target.length; i++) {
            out_error[i] = (target[i] - output[i + 1]) * output[i + 1] * (1d - output[i + 1]);
        }

    }

    /**
     * 獲取隱含層的誤差
     * 
     * @param NeLaErr
     *            下一層的誤差
     * @param Nextw
     *            下一層的權值
     * @param output 下一層的輸入
     * @param error
     *            本層誤差陣列
     */
    public void get_hide_error(double[] NeLaErr, double[][] Nextw, double[] output, double[] error) {

        for (int k = 0; k < error.length; k++) {
            double sum = 0;
            for (int j = 0; j < Nextw.length; j++) {
                sum += Nextw[j][k + 1] * NeLaErr[j];
            }
            error[k] = sum * output[k + 1] * (1d - output[k + 1]);
        }
    }

    /**
     * 向前傳播
     * 
     * @param x
     *            輸入值
     * @param output
     *            輸出值
     */
    public void forword(double[] x, double[] output) {

        // 2.2.1、獲取隱含層的輸出
        get_net_out(x, hide1_w, out_x);
        // 2.2.2、獲取輸出層的輸出
        get_net_out(out_x, out_w, output);

    }

    /**
     * 獲取單個節點的輸出
     * 
     * @param x
     *            輸入矩陣
     * @param w
     *            權值
     * @return 輸出值
     */
    private double get_node_put(double[] x, double[] w) {
        double z = 0d;

        for (int i = 0; i < x.length; i++) {
            z += x[i] * w[i];
        }
        // 2.激勵函式
        return 1d / (1d + Math.exp(-z));
    }

    /**
     * 獲取網路層的輸出
     * 
     * @param x
     *            輸入矩陣
     * @param w
     *            權值矩陣
     * @param net_out
     *            接收網路層的輸出陣列
     */
    private void get_net_out(double[] x, double[][] w, double[] net_out) {

        net_out[0] = 1d;
        for (int i = 0; i < w.length; i++) {
            net_out[i + 1] = get_node_put(x, w[i]);
        }

    }

}

(二) BP神經網路的測試

用上面實現的BP神經網路來訓練模型,自動判斷它是正數還是複數,奇數還是偶數.

package com.fei.bp02;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class Test {

    /**
     * @param args
     * @throws IOException
     */
    public static void main(String[] args) throws IOException {
        
    
        Bp bp = new Bp(32, 15, 4, 0.05);

        Random random = new Random();
        
        List<Integer> list = new ArrayList<Integer>();
        for (int i = 0; i < 6000; i++) {
            int value = random.nextInt(1000);//1000內的隨機數
            list.add(value);
            list.add(0-value);
        }

        
        for (int i = 0; i !=25; i++) {
            for (int value : list) {
                double[] real = new double[4];
                if (value >= 0)
                    if ((value & 1) == 1)
                        real[0] = 1;
                    else
                        real[1] = 1;
                else if ((value & 1) == 1)
                    real[2] = 1;
                else
                    real[3] = 1;
                
                double[] binary = new double[32];
                int index = 31;
                do {
                    binary[index--] = (value & 1);
                    value >>>= 1;
                } while (value != 0);

                bp.train(binary, real);
               
                

            }
        }
        

        
        
        System.out.println("訓練完畢,下面請輸入一個任意數字(-1000--1000),神經網路將自動判斷它是正數還是複數,奇數還是偶數。");

        while (true) {
            
            byte[] input = new byte[10];
            System.in.read(input);
            Integer value = Integer.parseInt(new String(input).trim());
            int rawVal = value;
            double[] binary = new double[32];
            int index = 31;
            do {
                binary[index--] = (value & 1);
                value >>>= 1;
            } while (value != 0);

            double[] result =new double[4];
             bp.predict(binary,result);

             
            double max = -Integer.MIN_VALUE;
            int idx = -1;

            for (int i = 0; i != result.length; i++) {
                if (result[i] > max) {
                    max = result[i];
                    idx = i;
                }
            }

            switch (idx) {
            case 0:
                System.out.format("%d是一個正奇數\n", rawVal);
                break;
            case 1:
                System.out.format("%d是一個正偶數\n", rawVal);
                break;
            case 2:
                System.out.format("%d是一個負奇數\n", rawVal);
                break;
            case 3:
                System.out.format("%d是一個負偶數\n", rawVal);
                break;
            }
        }
    }
}