1. 程式人生 > 【深度學習】BP演算法分類iris資料集

【深度學習】BP演算法分類iris資料集

這裡寫圖片描述
Network:

package test2;

import java.util.Random;

/**
 * A three-layer feed-forward neural network (input, hidden, output) trained by
 * error back-propagation with a momentum term.
 *
 * <p>Both layers use the logistic {@link #sigmoid(double)} activation. No bias/threshold
 * terms are used: each neuron's net input is only the weighted sum of the previous layer.
 *
 * <p>{@link #init(int, int, int)} must be called before {@link #train(double[], double[])}
 * or {@link #test(double[])}; until then the layer arrays are {@code null}.
 */
public class Network {

    private double[] input; // input layer activations
    private double[] hidden; // hidden layer activations
    private double[] output; // output layer activations
    private double[] target; // expected output vector for the current training sample
    private double[][] i_h_weight; // input -> hidden weights, indexed [input][hidden]
    private double[][] h_o_weight; // hidden -> output weights, indexed [hidden][output]
    private double[][] i_h_weightUpdate; // previous input -> hidden weight deltas (for momentum)
    private double[][] h_o_weightUpdate; // previous hidden -> output weight deltas (for momentum)
    private double[] outputError; // output layer error terms
    private double[] hiddenError; // hidden layer error terms
    private double outputErrorSum; // sum of absolute output error terms of the last sample
    private double hiddenErrorSum; // sum of absolute hidden error terms of the last sample
    private double rate = 0.25; // learning rate; init() resets it to 0.2
    private double momentum = 0.3; // momentum coefficient
    private Random random;

    /**
     * Allocates all layers and weight matrices and randomises the weights.
     *
     * @param inputSize  number of input neurons
     * @param hiddenSize number of hidden neurons
     * @param outputSize number of output neurons
     */
    public void init(int inputSize, int hiddenSize, int outputSize) {
        input = new double[inputSize];
        hidden = new double[hiddenSize];
        output = new double[outputSize];
        target = new double[outputSize];

        i_h_weight = new double[inputSize][hiddenSize];
        h_o_weight = new double[hiddenSize][outputSize];
        i_h_weightUpdate = new double[inputSize][hiddenSize];
        h_o_weightUpdate = new double[hiddenSize][outputSize];

        outputError = new double[outputSize];
        hiddenError = new double[hiddenSize];

        rate = 0.2;
        momentum = 0.3;

        random = new Random();
        randomWeights(i_h_weight);
        randomWeights(h_o_weight);
    }

    /**
     * Fills the matrix with random weights.
     *
     * <p>NOTE(review): a draw above 0.5 is kept and any other draw is negated, so the
     * weights fall in (0.5, 1) or [-0.5, 0] rather than in a range symmetric around zero.
     * Kept as in the original; confirm whether a symmetric range was intended.
     *
     * @param matrix weight matrix to initialise in place
     */
    private void randomWeights(double[][] matrix) {
        for (double[] row : matrix) {
            for (int j = 0; j < row.length; j++) {
                double real = random.nextDouble();
                row[j] = real > 0.5 ? real : -real;
            }
        }
    }

    /**
     * Runs one training step on a single sample: forward pass, error computation and
     * weight adjustment.
     *
     * @param trainData input vector; its length must equal the input layer size
     * @param target    expected output vector; its length must equal the output layer size
     * @throws IllegalArgumentException if either length does not match
     */
    public void train(double[] trainData, double[] target) {
        loadInput(trainData);
        loadTarget(target);
        forward();
        calculateError();
        adjustWeight();
    }

    /**
     * Feeds a sample through the network without changing the weights.
     *
     * @param inData input vector; its length must equal the input layer size
     * @return a copy of the output layer activations
     * @throws IllegalArgumentException if the length does not match
     */
    public double[] test(double[] inData) {
        loadInput(inData);
        forward();
        return getNetworkOutput();
    }

    /**
     * Returns the current network output.
     *
     * @return a copy of the output layer, so callers cannot modify internal state
     */
    private double[] getNetworkOutput() {
        return output.clone();
    }

    /**
     * Loads the expected output vector.
     *
     * @param target expected output vector
     * @throws IllegalArgumentException if its length differs from the output layer size
     */
    private void loadTarget(double[] target) {
        if (this.target.length != target.length) {
            throw new IllegalArgumentException("長度不匹配.");
        }
        // Copy instead of aliasing so the network never shares a buffer with the caller.
        System.arraycopy(target, 0, this.target, 0, target.length);
    }

    /**
     * Loads the input vector.
     *
     * @param input input vector
     * @throws IllegalArgumentException if its length differs from the input layer size
     */
    private void loadInput(double[] input) {
        if (this.input.length != input.length) {
            throw new IllegalArgumentException("長度不匹配.");
        }
        // Copy instead of aliasing so the network never shares a buffer with the caller.
        System.arraycopy(input, 0, this.input, 0, input.length);
    }

    /**
     * Propagates one layer forward: each neuron of {@code layer1} becomes the sigmoid of
     * the weighted sum of {@code layer0}.
     *
     * @param layer0 source layer activations
     * @param layer1 destination layer, overwritten in place
     * @param weight weights indexed [source][destination]
     */
    private void forward(double[] layer0, double[] layer1, double[][] weight) {
        for (int j = 0; j < layer1.length; j++) {
            double sum = 0;
            for (int i = 0; i < layer0.length; i++) {
                sum += weight[i][j] * layer0[i];
            }
            layer1[j] = sigmoid(sum);
        }
    }

    /**
     * Full forward pass: input to hidden, then hidden to output.
     */
    public void forward() {
        forward(input, hidden, i_h_weight);
        forward(hidden, output, h_o_weight);
    }

    /**
     * Computes the output layer error terms {@code o * (1 - o) * (target - o)} and the
     * sum of their absolute values.
     */
    private void outputError() {
        double errSum = 0;
        for (int i = 0; i < outputError.length; i++) {
            double o = output[i];
            outputError[i] = o * (1d - o) * (target[i] - o);
            errSum += Math.abs(outputError[i]);
        }
        outputErrorSum = errSum;
    }

    /**
     * Computes the hidden layer error terms by propagating the output error terms back
     * through the hidden-to-output weights, plus the sum of their absolute values.
     */
    private void hiddenError() {
        double errSum = 0;
        for (int i = 0; i < hiddenError.length; i++) {
            double o = hidden[i];
            double sum = 0;
            for (int j = 0; j < outputError.length; j++) {
                sum += h_o_weight[i][j] * outputError[j];
            }
            hiddenError[i] = o * (1d - o) * sum;
            errSum += Math.abs(hiddenError[i]);
        }
        hiddenErrorSum = errSum;
    }

    /**
     * Computes the error terms of both layers. The output layer must be done first
     * because the hidden layer errors are derived from it.
     */
    private void calculateError() {
        outputError();
        hiddenError();
    }

    /**
     * Updates one weight matrix: each delta is the momentum-scaled previous delta plus
     * {@code rate * error * activation}.
     *
     * @param error      error terms of the destination layer
     * @param layer      activations of the source layer
     * @param weight     weights indexed [source][destination], updated in place
     * @param prevWeight previous deltas, same shape as {@code weight}, updated in place
     */
    private void adjustWeight(double[] error, double[] layer, double[][] weight, double[][] prevWeight) {
        for (int i = 0; i < error.length; i++) {
            for (int j = 0; j < layer.length; j++) {
                double newVal = momentum * prevWeight[j][i] + rate * error[i] * layer[j];
                weight[j][i] += newVal;
                prevWeight[j][i] = newVal;
            }
        }
    }

    /**
     * Updates both weight matrices from the error terms of the current sample.
     */
    private void adjustWeight() {
        adjustWeight(hiddenError, input, i_h_weight, i_h_weightUpdate);
        adjustWeight(outputError, hidden, h_o_weight, h_o_weightUpdate);
    }

    /**
     * Logistic activation function; output lies in (0, 1) and is point-symmetric about
     * (0, 0.5).
     *
     * @param x net input
     * @return {@code 1 / (1 + e^-x)}
     */
    public double sigmoid(double x) {
        return 1 / (1 + Math.exp(-x));
    }

    /**
     * Hyperbolic tangent activation function; output lies in [-1, 1] and is
     * point-symmetric about the origin.
     *
     * <p>Delegates to {@link Math#tanh(double)}. The hand-written form
     * {@code (1 - e^-2x) / (1 + e^-2x)} overflows to {@code NaN} for large negative
     * {@code x}.
     *
     * @param x net input
     * @return the hyperbolic tangent of {@code x}
     */
    public double tanh(double x) {
        return Math.tanh(x);
    }
}

Main:

package test2;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import test.BP;

/**
 * Command-line driver: trains a {@link Network} on the samples read from
 * {@code data/iris.txt} and reports its classification accuracy on {@code data/test.txt}.
 *
 * <p>Training labels are derived from the row index, so the training rows are assumed to
 * be grouped by class in blocks of {@link #SAMPLES_PER_CLASS}, in the order of
 * {@link #CATEGORIES}.
 */
public class Main {

    /** Class labels, indexed by the position of the winning output neuron. */
    private static final String[] CATEGORIES = { "Iris-setosa", "Iris-versicolor", "Iris-virginica" };

    /** Number of consecutive training rows that belong to each class. */
    private static final int SAMPLES_PER_CLASS = 50;

    /**
     * Trains the network, then evaluates it on the test file and prints the accuracy.
     *
     * @param args unused
     * @throws IOException declared for the data-loading calls
     */
    public static void main(String[] args) throws IOException {
        System.out.println("->讀取樣本資料");
        ReadData rd = new ReadData();
        List<double[]> data = rd.loadData("data/iris.txt", 0, 3, ",");
        System.out.println("->讀取完成");
        System.out.println("->初始化神經網路");
        int ipt = 4;
        int opt = 3;
        int hid = (int) (Math.sqrt(ipt + opt) + 10);
        Network bp = new Network();
        bp.init(ipt, hid, opt);
        System.out.println("->初始化完成");
        int maxLearn = 10000;
        System.out.println("->最大學習次數:" + maxLearn);
        System.out.println("->開始訓練");
        // Millisecond timestamps are integral; keep them as long so the elapsed time
        // prints without a spurious ".0".
        long start = System.currentTimeMillis();
        for (int epoch = 0; epoch < maxLearn; epoch++) {
            for (int i = 0; i < data.size(); i++) {
                bp.train(data.get(i), oneHotTarget(i, opt));
            }
        }
        long end = System.currentTimeMillis();
        System.out.println("->訓練完成,用時:" + (end - start) + "ms");

        System.out.println("-------------");
        List<double[]> testData = rd.loadData("data/test.txt", 0, 3, ",");
        // Read the expected labels once, not once per test sample.
        List<?> expectedLabels = rd.getColumn("data/test.txt", 4, ",");
        int correct = 0;
        int error = 0;
        for (int i = 0; i < testData.size(); i++) {
            double[] result = bp.test(testData.get(i));
            if (classify(result).equals(expectedLabels.get(i))) {
                correct++;
            } else {
                error++;
            }
        }
        System.out.println("->測試資料:" + (correct + error) + "條," + "正確 " + correct + "條");
        System.out.println("->正確率:" + (float) correct / (correct + error));
    }

    /**
     * Builds the expected output vector for a training row from its position in the data.
     *
     * @param sampleIndex zero-based row index in the training data
     * @param outputSize  length of the vector to build
     * @return a vector with 1 at index {@code sampleIndex / SAMPLES_PER_CLASS} and 0
     *         elsewhere; all zeros when that index is beyond {@code outputSize}
     */
    private static double[] oneHotTarget(int sampleIndex, int outputSize) {
        double[] target = new double[outputSize];
        int classIndex = sampleIndex / SAMPLES_PER_CLASS;
        if (classIndex < outputSize) {
            target[classIndex] = 1;
        }
        return target;
    }

    /**
     * Maps a network output vector to a class label by picking the largest component.
     *
     * @param result output layer activations
     * @return the label at the index of the largest value, or an empty string when no
     *         index wins (empty or all-NaN input) or the index has no label
     */
    private static String classify(double[] result) {
        int best = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < result.length; i++) {
            if (result[i] > max) {
                max = result[i];
                best = i;
            }
        }
        return (best >= 0 && best < CATEGORIES.length) ? CATEGORIES[best] : "";
    }

}

結果:
這裡寫圖片描述