1. 程式人生 > >deeplearning4j實現多感知器的手寫數字識別

deeplearning4j實現多層感知器（MLP）的手寫數字識別

package com.itcast.wang.test_dl;

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.deeplearning4j.eval.Evaluation;  
import org.nd4j.linalg.api.ndarray.INDArray;
/**A Simple MLP applied to digit classification for MNIST.
 */
public class MLPMnistSingleLayerExample {
 
    private static Logger log = LoggerFactory.getLogger(MLPMnistSingleLayerExample.class);
 
    public static void main(String[] args) throws Exception {
 
        final int numRows = 28;//影象寬
        final int numColumns = 28;//影象長
        int outputNum = 10;//輸出的類別數
        int batchSize = 128;//沒128個樣本參加訓練
        int rngSeed = 123;//
        int numEpochs = 15;//訓練集樣本迭代的次數
 
        //Get the DataSetIterators:
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
        DataSetIterator mnist = new MnistDataSetIterator(batchSize, false, rngSeed);
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)//隨機梯度下降
                .iterations(1)
                .learningRate(0.006)//學習率
                .updater(Updater.NESTEROVS).momentum(0.9)//運動慣量
                .regularization(true).l2(1e-4)//是否使用正則化
                .list()
                .layer(0, new DenseLayer.Builder()//第一層網路配置
                        .nIn(numRows * numColumns)//輸入數目
                        .nOut(1000)//輸出數目
                        .activation("relu")//啟用函式 relu
                        .weightInit(WeightInit.XAVIER)//權值初始化
                        .build())
                //輸出層指定誤差函式
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)//誤差函式
                        .nIn(1000)//輸入
                        .nOut(outputNum)//輸出
                        .activation("softmax")//啟用函式
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .pretrain(false).backprop(true)
                .build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));
        log.info("Train model....");
        for(int i=0;i<4;i++){  
            model.fit(mnistTrain);  
            System.out.println(" Completed epoch is :" + i);  
  
                System.out.println("Evaluate model....");  
                Evaluation eval = new Evaluation(outputNum);  
                while(mnist.hasNext()){  
                    DataSet ds = mnist.next();  
                    INDArray output = model.output(ds.getFeatureMatrix(), false);  
                    eval.eval(ds.getLabels(), output);  
                }  
                System.out.println(eval.stats());  
                mnist.reset();  
              
        }  
        System.out.println("model finish");  
    }
    }