// 1. 程式人生 > 6、神經網路學習總結
//
// 6、神經網路學習總結

package com.jd;

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;

//        輸入:
//            訓練集 D = {(x1,y1),(x2,y2),...,(xm,ym)};
//            屬性集 A = {a1,a2,...,ad}.
//        過程:
//            在(0,1)範圍內隨機初始化網路中所有連線權和閾值
//            repeat
//                for all(xk,yk)屬於D do
//                    根據當前引數和式(5.3)計算當前樣本的輸出yk';
//                    根據式(5.10)計算輸出層神經元的梯度項gj;
//                    根據式(5.15)計算隱層神經元的梯度項eh;
//                    根據式(5.11)-(5.14)更新連線權whj,vih與閾值oj,yh
//                end for
//            until達到停止條件
//
//        輸出:連線權與閾值確定的多層前饋神經網路


public class neural_network {

    // Returns a fresh list holding the same elements as L, so later mutation
    // of either list's structure cannot affect the other.
    static List<Double> L_copy(List<Double> L){
        return new ArrayList<Double>(L);
    }

    // Deep-copies a list of lists: both the outer list and every inner list
    // are new, so mutating the copy never touches the original structure.
    static List<List<Double>> LL_copy(List<List<Double>> LL){
        List<List<Double>> result = new ArrayList<List<Double>>();
        for(int i = 0; i < LL.size(); i++){
            result.add(new ArrayList<Double>(LL.get(i)));
        }
        return result;
    }

    // Snapshot of all network parameters: the two weight matrices
    // (input->hidden and hidden->output) and the two threshold vectors.
    // The constructor takes defensive deep copies, so a snapshot saved as
    // "best so far" cannot be mutated by later training iterations.
    static class neural_network_parameter{

        public List<List<Double>> LL01;   // input -> hidden weights
        public List<Double> cut_off01;    // hidden-layer thresholds
        public List<List<Double>> LL12;   // hidden -> output weights
        public List<Double> cut_off12;    // output-layer thresholds

        public neural_network_parameter(List<List<Double>> LL01, List<Double> cut_off01,
                                        List<List<Double>> LL12, List<Double> cut_off12){
            this.LL01 = LL_copy(LL01);
            this.cut_off01 = L_copy(cut_off01);
            this.LL12 = LL_copy(LL12);
            this.cut_off12 = L_copy(cut_off12);
        }

    }

    //資料讀入
    // Reads the CSV supplied through `lines` into a list of play_example
    // objects. The first line (header row) is skipped; every remaining line
    // is expected to contain five comma-separated fields.
    static List<play_example> exampleListMapInit(BufferedReader lines) throws Exception{

        List<play_example> exampleList = new ArrayList<play_example>();

        // Discard the header row.
        lines.readLine();

        for(String line = lines.readLine(); line != null; line = lines.readLine()){
            String[] fields = line.split(",");
            exampleList.add(new play_example(fields[0], fields[1], fields[2], fields[3], fields[4]));
        }

        return exampleList;
    }

    // Builds a num1 x num2 weight matrix whose entries are drawn uniformly
    // from [-0.5, 0.5): the random initialisation for one network layer.
    static List<List<Double>> weight_matrix_generation(int num1, int num2){

        List<List<Double>> matrix = new ArrayList<List<Double>>();

        for(int row = 0; row < num1; row++){
            List<Double> weights = new ArrayList<Double>();
            for(int col = 0; col < num2; col++){
                weights.add(Math.random() - 0.5);
            }
            matrix.add(weights);
        }

        return matrix;
    }

    // Builds a threshold (bias) vector of length num1 with entries drawn
    // uniformly from [-0.5, 0.5).
    static List<Double> cut_off_generation(int num1){

        List<Double> thresholds = new ArrayList<Double>();
        for(int i = 0; i < num1; i++){
            thresholds.add(Math.random() - 0.5);
        }
        return thresholds;
    }

    // Flattens each play_example into a 5-element numeric row:
    // [outlook, temperature, humidity, windy, is_play]. The last column is
    // the training label; the first four are the input features.
    static  List<List<Double>> data_pretreatment(List<play_example> exampleList){

        List<List<Double>> data = new ArrayList<List<Double>>();

        for(int i = 0; i < exampleList.size(); i++){
            play_example example = exampleList.get(i);
            List<Double> row = new ArrayList<Double>();
            row.add(example.outlook);
            row.add(example.temperature);
            row.add(example.humidity);
            row.add(example.windy);
            row.add(example.is_play);
            data.add(row);
        }

        return data;
    }

    // Sigmoid activation: maps any real input into (0, 1).
    static Double activation(Double cell){
        return 1/(1+Math.exp(-cell));
    }

    // Forward-propagates a whole data set through one layer.
    //   datai      : one row per sample, one column per source-layer unit
    //   LLi_1      : weight matrix; LLi_1.get(k).get(j) connects source unit k
    //                to target unit j
    //   cut_offi_1 : threshold of each target unit
    // Returns the activated outputs, one row per sample.
    static List<List<Double>> spread(List<List<Double>> datai, List<List<Double>> LLi_1, List<Double> cut_offi_1){

        List<List<Double>> datai_1 = new ArrayList<List<Double>>();

        for(int i=0; i<datai.size(); i++){
            List<Double> L = new ArrayList<Double>();
            for(int j=0; j<LLi_1.get(0).size(); j++){
                // Weighted sum of the sample's inputs into target unit j,
                // then subtract the unit's threshold before activating.
                double cell = 0;
                for(int k=0; k<datai.get(0).size(); k++){
                    cell += datai.get(i).get(k)* LLi_1.get(k).get(j);
                }
                L.add(activation(cell-cut_offi_1.get(j)));
            }
            datai_1.add(L);
        }

        return datai_1;
    }

    // Sum of squared errors over the whole data set for the given parameters.
    // Only output unit 0 is compared against Y (the network has one output).
    static Double error(List<List<Double>> X, List<List<Double>> LL01, List<Double> cut_off01,
                        List<List<Double>> LL12, List<Double> cut_off12, List<Double> Y){

        List<List<Double>> data1 = spread(X,LL01,cut_off01);
        List<List<Double>> data2 = spread(data1,LL12,cut_off12);

        // Renamed from `error` to avoid shadowing this method's own name.
        double sum = 0;
        for(int i=0; i<Y.size(); i++){
            sum += Math.pow(Y.get(i)-data2.get(i).get(0),2);
        }

        return sum;
    }

    // One stochastic-gradient BP step on sample `idex` (standard BP update
    // rules, eqs. 5.10-5.15 in the pseudocode above). All four parameter
    // structures are updated in place; `rate` is the learning rate.
    static void random_gradient_iteration(List<List<Double>> X, List<List<Double>> LL01, List<Double> cut_off01,
                                   List<List<Double>> LL12, List<Double> cut_off12, List<Double> Y, double rate, int idex){

        // Forward pass (only row `idex` of the results is used below).
        List<List<Double>> data1 = spread(X,LL01,cut_off01);
        List<List<Double>> data2 = spread(data1,LL12,cut_off12);

        List<List<Double>> dLL01 = new ArrayList<List<Double>>();
        List<List<Double>> dLL12 = new ArrayList<List<Double>>();
        List<Double> dcut_off01 = new ArrayList<Double>();
        List<Double> dcut_off12 = new ArrayList<Double>();

        // Output-layer threshold delta: -eta*g_j with
        // g_j = yhat_j*(1-yhat_j)*(y-yhat_j)   (eq. 5.10 / 5.12).
        // Fixed: use output j (was hard-coded to output 0 for every j).
        for(int j=0; j<cut_off12.size(); j++){
            double out = data2.get(idex).get(j);
            dcut_off12.add(rate*(out-Y.get(idex))*out*(1-out));
        }

        // Hidden->output weight delta: eta*g_j*b_h   (eq. 5.11).
        for(int h=0; h<LL12.size(); h++) {
            List<Double> L = new ArrayList<Double>();
            for (int j = 0; j < LL12.get(h).size(); j++) {
                L.add(-dcut_off12.get(j)*data1.get(idex).get(h));
            }
            dLL12.add(L);
        }

        // Hidden-layer threshold delta: -eta*e_h with
        // e_h = b_h*(1-b_h) * sum_j w_hj*g_j   (eq. 5.15 / 5.14).
        // BUG FIX: the sum must use the output-layer gradient term
        // (dcut_off12.get(j) == -eta*g_j), not the output threshold value
        // cut_off12.get(j) that the original code multiplied in.
        for(int h=0; h<cut_off01.size(); h++){
            double b = data1.get(idex).get(h);
            double sum = 0;
            for(int j=0; j< LL12.get(h).size(); j++){
                sum += b*(1-b)*dcut_off12.get(j)*LL12.get(h).get(j);
            }
            dcut_off01.add(sum);
        }

        // Input->hidden weight delta: eta*e_h*x_i   (eq. 5.13).
        for(int i=0; i<LL01.size(); i++){
            List<Double> L = new ArrayList<Double>();
            for(int j=0; j<LL01.get(i).size(); j++){
                L.add(-dcut_off01.get(j)*X.get(idex).get(i));
            }
            dLL01.add(L);
        }

        // Apply all four deltas in place.
        for(int i=0; i<LL01.size(); i++){
            for(int j=0; j<LL01.get(i).size(); j++){
                LL01.get(i).set(j,LL01.get(i).get(j)+dLL01.get(i).get(j));
            }
        }

        for(int i=0; i<LL12.size(); i++){
            for(int j=0; j<LL12.get(i).size(); j++){
                LL12.get(i).set(j,LL12.get(i).get(j)+dLL12.get(i).get(j));
            }
        }

        for(int i=0; i<cut_off01.size(); i++){
            cut_off01.set(i,cut_off01.get(i)+dcut_off01.get(i));
        }

        for(int i=0; i<cut_off12.size(); i++){
            cut_off12.set(i,cut_off12.get(i)+dcut_off12.get(i));
        }
    }

    // Trains a single-hidden-layer feed-forward network with stochastic BP.
    // Each data row is [x1..xd, y]; the hidden width is fixed at 4 and there
    // is one output unit. Returns the best parameter snapshot found during
    // training (previously the snapshot was computed but discarded, and the
    // method returned nothing).
    static neural_network_parameter BP(List<List<Double>> data){

        // Input dimension d, hidden width q, output width l.
        int d = data.get(0).size()-1;
        int q = 4;
        int l = 1;

        // Split each row into features X and label Y (last column).
        List<List<Double>> X = new ArrayList<List<Double>>();
        List<Double> Y = new ArrayList<Double>();
        for(int i=0; i<data.size(); i++){
            List<Double> row = data.get(i);
            X.add(new ArrayList<Double>(row.subList(0, row.size()-1)));
            Y.add(row.get(row.size()-1));
        }

        // Random initial weights and thresholds in [-0.5, 0.5).
        // A (d,q) matrix is q column-vectors of dimension d.
        List<List<Double>> LL01 = weight_matrix_generation(d,q);
        List<Double> cut_off01 = cut_off_generation(q);
        List<List<Double>> LL12 = weight_matrix_generation(q,l);
        List<Double> cut_off12 = cut_off_generation(l);

        neural_network_parameter best = new neural_network_parameter(LL01,cut_off01,LL12,cut_off12);
        // Renamed from `error` to avoid shadowing the error(...) method.
        double bestError = error(X, LL01, cut_off01, LL12, cut_off12, Y);

        double rate = 0.1;
        for(int step=0; step<1000; step++){
            // One stochastic step, cycling through the samples in order.
            random_gradient_iteration(X, LL01, cut_off01, LL12, cut_off12, Y, rate, step%Y.size());
            // Evaluate once per step (was recomputed up to three times).
            double current = error(X, LL01, cut_off01, LL12, cut_off12, Y);
            if(current < bestError){
                bestError = current;
                best = new neural_network_parameter(LL01,cut_off01,LL12,cut_off12);
                System.out.println(bestError);
            }
        }

        return best;
    }


    // Entry point: loads the CSV data set and trains the network.
    // The data-file path may be overridden by the first command-line
    // argument; otherwise the original hard-coded path is used.
    public static void main(String[] args) throws Exception {

        String path = args.length > 0 ? args[0]
                : "C:\\Users\\zhangchaoyu\\Desktop\\zcy\\java_and_scala\\machine_learning\\src\\main\\resources\\14.csv";

        // try-with-resources: the reader was previously never closed.
        try (BufferedReader lines = new BufferedReader(new FileReader(path))) {
            List<play_example> exampleList = exampleListMapInit(lines);
            List<List<Double>> data = data_pretreatment(exampleList);
            BP(data);
        }
    }



}