
Training a convolutional neural network on the MNIST dataset with MXNet (test run)

import numpy as np
import mxnet as mx
import logging

logging.getLogger().setLevel(logging.DEBUG)

batch_size = 100
# download MNIST (or reuse cached files) and wrap the arrays in batch iterators
mnist = mx.test_utils.get_mnist()
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

data = mx.sym.var('data')
# first conv layer: 20 filters of 5x5 -> tanh -> 2x2 max pooling
conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2))
# second conv layer: 50 filters of 5x5 -> tanh -> 2x2 max pooling
conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2))
# first fully connected layer: 500 hidden units
flatten = mx.sym.Flatten(data=pool2)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
# second fully connected layer: one output per digit class
fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
# softmax loss
lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')

# create a trainable module on the CPU (pass context=mx.gpu(0) to train on GPU 0 instead)
lenet_model = mx.mod.Module(
                symbol=lenet,
                context=mx.cpu())

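# train with SGD (learning rate 0.1) for 10 epochs, evaluating accuracy on the
# test set after each epoch and logging throughput every 100 batches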
lenet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='sgd',
                optimizer_params={'learning_rate':0.1},
                eval_metric='acc',
                batch_end_callback = mx.callback.Speedometer(batch_size, 100),
                num_epoch=10)
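As a sanity check on the architecture, the layer sizes of the LeNet symbol above can be worked out by hand. The sketch below is plain Python (no MXNet needed) and assumes MXNet's defaults of stride 1 and zero padding for `Convolution` and floor ("valid") rounding for `Pooling`; the helper names `conv_out` and `lenet_shapes` are hypothetical, not part of MXNet.

```python
# Pure-Python sketch: output shapes and parameter counts of the LeNet above.
# Assumes MXNet defaults: stride 1 / no padding for Convolution,
# floor ("valid") rounding for Pooling.


def conv_out(size, kernel, stride=1, pad=0):
    """Spatial size after a convolution or pooling window."""
    return (size + 2 * pad - kernel) // stride + 1


def lenet_shapes(in_size=28, in_channels=1):
    """Return (name, output shape without batch dim, parameter count) per layer."""
    layers = []
    size, ch = in_size, in_channels

    # conv1: 20 filters of 5x5 (weights + one bias per filter)
    size = conv_out(size, 5)
    layers.append(("conv1", (20, size, size), 20 * ch * 5 * 5 + 20))
    ch = 20
    # pool1: 2x2 max pooling, stride 2 (no parameters)
    size = conv_out(size, 2, stride=2)
    layers.append(("pool1", (ch, size, size), 0))

    # conv2: 50 filters of 5x5 over 20 input channels
    size = conv_out(size, 5)
    layers.append(("conv2", (50, size, size), 50 * ch * 5 * 5 + 50))
    ch = 50
    # pool2: 2x2 max pooling, stride 2
    size = conv_out(size, 2, stride=2)
    layers.append(("pool2", (ch, size, size), 0))

    # flatten, then the two fully connected layers
    flat = ch * size * size
    layers.append(("flatten", (flat,), 0))
    layers.append(("fc1", (500,), flat * 500 + 500))
    layers.append(("fc2", (10,), 500 * 10 + 10))
    return layers


if __name__ == "__main__":
    total = 0
    for name, shape, params in lenet_shapes():
        total += params
        print(f"{name:8s} {str(shape):14s} params={params}")
    print("total params:", total)
```

Under these assumptions a 28x28 image shrinks to 24, 12, 8 and finally 4 pixels per side, so `Flatten` produces 50 x 4 x 4 = 800 features, and the two fully connected layers hold most of the roughly 431k parameters.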

Output:

INFO:root:train-labels-idx1-ubyte.gz exists, skip to downloada
INFO:root:train-images-idx3-ubyte.gz exists, skip to downloada
INFO:root:t10k-labels-idx1-ubyte.gz exists, skip to downloada
INFO:root:t10k-images-idx3-ubyte.gz exists, skip to downloada
INFO:root:Epoch[0] Batch [100]  Speed: 722.13 samples/sec       accuracy=0.103366
INFO:root:Epoch[0] Batch [200]  Speed: 713.60 samples/sec       accuracy=0.115500
INFO:root:Epoch[0] Batch [300]  Speed: 714.94 samples/sec       accuracy=0.110900
INFO:root:Epoch[0] Batch [400]  Speed: 709.44 samples/sec       accuracy=0.111200
INFO:root:Epoch[0] Batch [500]  Speed: 714.26 samples/sec       accuracy=0.114600
INFO:root:Epoch[0] Train-accuracy=0.113434
INFO:root:Epoch[0] Time cost=83.928
INFO:root:Epoch[0] Validation-accuracy=0.113500
INFO:root:Epoch[1] Batch [100]  Speed: 716.48 samples/sec       accuracy=0.161683
INFO:root:Epoch[1] Batch [200]  Speed: 675.00 samples/sec       accuracy=0.591100
INFO:root:Epoch[1] Batch [300]  Speed: 668.75 samples/sec       accuracy=0.861500
INFO:root:Epoch[1] Batch [400]  Speed: 647.97 samples/sec       accuracy=0.899400
INFO:root:Epoch[1] Batch [500]  Speed: 666.97 samples/sec       accuracy=0.920600
INFO:root:Epoch[1] Train-accuracy=0.932828
INFO:root:Epoch[1] Time cost=88.947
INFO:root:Epoch[1] Validation-accuracy=0.940800
INFO:root:Epoch[2] Batch [100]  Speed: 660.08 samples/sec       accuracy=0.944653
INFO:root:Epoch[2] Batch [200]  Speed: 650.96 samples/sec       accuracy=0.954200
INFO:root:Epoch[2] Batch [300]  Speed: 669.57 samples/sec       accuracy=0.958800
INFO:root:Epoch[2] Batch [400]  Speed: 644.97 samples/sec       accuracy=0.963200
INFO:root:Epoch[2] Batch [500]  Speed: 654.75 samples/sec       accuracy=0.967100
INFO:root:Epoch[2] Train-accuracy=0.969394
INFO:root:Epoch[2] Time cost=91.671
INFO:root:Epoch[2] Validation-accuracy=0.973100
INFO:root:Epoch[3] Batch [100]  Speed: 660.64 samples/sec       accuracy=0.970990
INFO:root:Epoch[3] Batch [200]  Speed: 669.49 samples/sec       accuracy=0.974400
INFO:root:Epoch[3] Batch [300]  Speed: 650.88 samples/sec       accuracy=0.973900
INFO:root:Epoch[3] Batch [400]  Speed: 665.29 samples/sec       accuracy=0.976800
INFO:root:Epoch[3] Batch [500]  Speed: 664.31 samples/sec       accuracy=0.976000
INFO:root:Epoch[3] Train-accuracy=0.978384
INFO:root:Epoch[3] Time cost=90.576
INFO:root:Epoch[3] Validation-accuracy=0.981600
INFO:root:Epoch[4] Batch [100]  Speed: 657.94 samples/sec       accuracy=0.978416
INFO:root:Epoch[4] Batch [200]  Speed: 651.82 samples/sec       accuracy=0.980100
INFO:root:Epoch[4] Batch [300]  Speed: 653.96 samples/sec       accuracy=0.982100
INFO:root:Epoch[4] Batch [400]  Speed: 647.17 samples/sec       accuracy=0.982400
INFO:root:Epoch[4] Batch [500]  Speed: 656.77 samples/sec       accuracy=0.981900
INFO:root:Epoch[4] Train-accuracy=0.984646
INFO:root:Epoch[4] Time cost=91.804
INFO:root:Epoch[4] Validation-accuracy=0.983400
INFO:root:Epoch[5] Batch [100]  Speed: 649.50 samples/sec       accuracy=0.983069
INFO:root:Epoch[5] Batch [200]  Speed: 649.20 samples/sec       accuracy=0.984600
INFO:root:Epoch[5] Batch [300]  Speed: 647.68 samples/sec       accuracy=0.985200
INFO:root:Epoch[5] Batch [400]  Speed: 658.71 samples/sec       accuracy=0.985900
INFO:root:Epoch[5] Batch [500]  Speed: 646.41 samples/sec       accuracy=0.984900
INFO:root:Epoch[5] Train-accuracy=0.987071
INFO:root:Epoch[5] Time cost=92.219
INFO:root:Epoch[5] Validation-accuracy=0.985100
INFO:root:Epoch[6] Batch [100]  Speed: 645.74 samples/sec       accuracy=0.985842
INFO:root:Epoch[6] Batch [200]  Speed: 653.40 samples/sec       accuracy=0.987800
INFO:root:Epoch[6] Batch [300]  Speed: 646.12 samples/sec       accuracy=0.987800
INFO:root:Epoch[6] Batch [400]  Speed: 641.82 samples/sec       accuracy=0.988100
INFO:root:Epoch[6] Batch [500]  Speed: 643.05 samples/sec       accuracy=0.986900
INFO:root:Epoch[6] Train-accuracy=0.989192
INFO:root:Epoch[6] Time cost=96.044
INFO:root:Epoch[6] Validation-accuracy=0.986100
INFO:root:Epoch[7] Batch [100]  Speed: 653.00 samples/sec       accuracy=0.987327
INFO:root:Epoch[7] Batch [200]  Speed: 650.61 samples/sec       accuracy=0.988800
INFO:root:Epoch[7] Batch [300]  Speed: 649.02 samples/sec       accuracy=0.989100
INFO:root:Epoch[7] Batch [400]  Speed: 644.93 samples/sec       accuracy=0.990000
INFO:root:Epoch[7] Batch [500]  Speed: 554.87 samples/sec       accuracy=0.988700
INFO:root:Epoch[7] Train-accuracy=0.990202
INFO:root:Epoch[7] Time cost=94.743
INFO:root:Epoch[7] Validation-accuracy=0.987600
INFO:root:Epoch[8] Batch [100]  Speed: 649.92 samples/sec       accuracy=0.988812
INFO:root:Epoch[8] Batch [200]  Speed: 654.07 samples/sec       accuracy=0.990800
INFO:root:Epoch[8] Batch [300]  Speed: 656.73 samples/sec       accuracy=0.990700
INFO:root:Epoch[8] Batch [400]  Speed: 653.70 samples/sec       accuracy=0.990900
INFO:root:Epoch[8] Batch [500]  Speed: 631.36 samples/sec       accuracy=0.990200
INFO:root:Epoch[8] Train-accuracy=0.991616
INFO:root:Epoch[8] Time cost=92.349
INFO:root:Epoch[8] Validation-accuracy=0.988500
INFO:root:Epoch[9] Batch [100]  Speed: 647.88 samples/sec       accuracy=0.990792
INFO:root:Epoch[9] Batch [200]  Speed: 635.89 samples/sec       accuracy=0.991900
INFO:root:Epoch[9] Batch [300]  Speed: 637.18 samples/sec       accuracy=0.991700
INFO:root:Epoch[9] Batch [400]  Speed: 640.23 samples/sec       accuracy=0.992300
INFO:root:Epoch[9] Batch [500]  Speed: 640.93 samples/sec       accuracy=0.991900
INFO:root:Epoch[9] Train-accuracy=0.992828
INFO:root:Epoch[9] Time cost=93.533
INFO:root:Epoch[9] Validation-accuracy=0.988700
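The log can also be summarized programmatically. The sketch below parses the `Time cost` and `Validation-accuracy` lines (copied verbatim from the output above) with regular expressions; `parse_log` and `summarize` are hypothetical helpers. It also divides the 60,000 MNIST training images by each epoch's time as a rough cross-check against the Speedometer throughput.

```python
import re

# Lines copied verbatim from the training log above.
LOG = """\
INFO:root:Epoch[0] Time cost=83.928
INFO:root:Epoch[0] Validation-accuracy=0.113500
INFO:root:Epoch[1] Time cost=88.947
INFO:root:Epoch[1] Validation-accuracy=0.940800
INFO:root:Epoch[2] Time cost=91.671
INFO:root:Epoch[2] Validation-accuracy=0.973100
INFO:root:Epoch[3] Time cost=90.576
INFO:root:Epoch[3] Validation-accuracy=0.981600
INFO:root:Epoch[4] Time cost=91.804
INFO:root:Epoch[4] Validation-accuracy=0.983400
INFO:root:Epoch[5] Time cost=92.219
INFO:root:Epoch[5] Validation-accuracy=0.985100
INFO:root:Epoch[6] Time cost=96.044
INFO:root:Epoch[6] Validation-accuracy=0.986100
INFO:root:Epoch[7] Time cost=94.743
INFO:root:Epoch[7] Validation-accuracy=0.987600
INFO:root:Epoch[8] Time cost=92.349
INFO:root:Epoch[8] Validation-accuracy=0.988500
INFO:root:Epoch[9] Time cost=93.533
INFO:root:Epoch[9] Validation-accuracy=0.988700
"""

TIME_RE = re.compile(r"Epoch\[(\d+)\] Time cost=([\d.]+)")
VAL_RE = re.compile(r"Epoch\[(\d+)\] Validation-accuracy=([\d.]+)")


def parse_log(text):
    """Return ({epoch: seconds}, {epoch: validation accuracy})."""
    times = {int(e): float(t) for e, t in TIME_RE.findall(text)}
    val_acc = {int(e): float(a) for e, a in VAL_RE.findall(text)}
    return times, val_acc


def summarize(text, num_train=60000):
    """Best epoch, its accuracy, total training time, implied samples/sec per epoch."""
    times, val_acc = parse_log(text)
    best_epoch = max(val_acc, key=val_acc.get)
    throughput = {e: num_train / t for e, t in times.items()}
    return best_epoch, val_acc[best_epoch], sum(times.values()), throughput


if __name__ == "__main__":
    best_epoch, best_acc, total_time, throughput = summarize(LOG)
    print(f"best epoch: {best_epoch} (validation accuracy {best_acc:.4f})")
    print(f"total training time: {total_time:.1f} s")
    for epoch, sps in sorted(throughput.items()):
        print(f"epoch {epoch}: ~{sps:.0f} samples/sec")
```

For epoch 0, 60,000 / 83.928 s comes to about 715 samples/sec, close to the roughly 714 samples/sec that Speedometer reports. With `batch_size = 100`, each epoch is 600 batches, which is consistent with the reports appearing at Batch [100] through [500].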