mxnet卷積神經網路訓練MNIST資料集測試
阿新 • 發佈:2019-02-07
"""Train a LeNet-style CNN on MNIST with MXNet's symbolic Module API.

Architecture: two conv(5x5) -> tanh -> 2x2 max-pool stages (20 and 50
filters), a 500-unit tanh fully connected layer, and a 10-way softmax output.
Training uses plain SGD (learning rate 0.1) for 10 epochs on the CPU, and
reports accuracy on the MNIST test split after every epoch.
"""
import numpy as np
import mxnet as mx
import logging

# Make the progress messages emitted during Module.fit visible.
logging.getLogger().setLevel(logging.DEBUG)

batch_size = 100

# get_mnist() returns a dict with 'train_data', 'train_label',
# 'test_data' and 'test_label' entries (downloading the files if absent).
mnist = mx.test_utils.get_mnist()
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'],
                               batch_size, shuffle=True)
# The test split is used as the validation set; it is not shuffled.
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'],
                             batch_size)

data = mx.sym.var('data')

# First conv stage: 20 filters of 5x5, tanh, 2x2 max-pool with stride 2.
conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
pool1 = mx.sym.Pooling(data=tanh1, pool_type="max",
                       kernel=(2, 2), stride=(2, 2))

# Second conv stage: 50 filters of 5x5, tanh, 2x2 max-pool with stride 2.
conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
pool2 = mx.sym.Pooling(data=tanh2, pool_type="max",
                       kernel=(2, 2), stride=(2, 2))

# First fully connected layer: flatten the feature maps, 500 tanh units.
flatten = mx.sym.Flatten(data=pool2)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")

# Second fully connected layer: one output per digit class (10 classes).
fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)

# Softmax output with cross-entropy loss; the label input is 'softmax_label'.
lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')

# Create a trainable module on the CPU.
# To train on the first GPU instead, pass context=mx.gpu(0).
lenet_model = mx.mod.Module(symbol=lenet, context=mx.cpu())

# Train with SGD (lr=0.1) for 10 epochs, tracking accuracy.
# Speedometer logs throughput and running accuracy every 100 batches.
lenet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='sgd',
                optimizer_params={'learning_rate': 0.1},
                eval_metric='acc',
                batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                num_epoch=10)
顯示結果:
INFO:root:train-labels-idx1-ubyte.gz exists, skip to downloada
INFO:root:train-images-idx3-ubyte.gz exists, skip to downloada
INFO:root:t10k-labels-idx1-ubyte.gz exists, skip to downloada
INFO:root:t10k-images-idx3-ubyte.gz exists, skip to downloada
INFO:root:Epoch[0] Batch [100] Speed: 722.13 samples/sec accuracy=0.103366
INFO:root:Epoch[0] Batch [200] Speed: 713.60 samples/sec accuracy=0.115500
INFO:root:Epoch[0] Batch [300] Speed: 714.94 samples/sec accuracy=0.110900
INFO:root:Epoch[0] Batch [400] Speed: 709.44 samples/sec accuracy=0.111200
INFO:root:Epoch[0] Batch [500] Speed: 714.26 samples/sec accuracy=0.114600
INFO:root:Epoch[0] Train-accuracy=0.113434
INFO:root:Epoch[0] Time cost=83.928
INFO:root:Epoch[0] Validation-accuracy=0.113500
INFO:root:Epoch[1] Batch [100] Speed: 716.48 samples/sec accuracy=0.161683
INFO:root:Epoch[1] Batch [200] Speed: 675.00 samples/sec accuracy=0.591100
INFO:root:Epoch[1] Batch [300] Speed: 668.75 samples/sec accuracy=0.861500
INFO:root:Epoch[1] Batch [400] Speed: 647.97 samples/sec accuracy=0.899400
INFO:root:Epoch[1] Batch [500] Speed: 666.97 samples/sec accuracy=0.920600
INFO:root:Epoch[1] Train-accuracy=0.932828
INFO:root:Epoch[1] Time cost=88.947
INFO:root:Epoch[1] Validation-accuracy=0.940800
INFO:root:Epoch[2] Batch [100] Speed: 660.08 samples/sec accuracy=0.944653
INFO:root:Epoch[2] Batch [200] Speed: 650.96 samples/sec accuracy=0.954200
INFO:root:Epoch[2] Batch [300] Speed: 669.57 samples/sec accuracy=0.958800
INFO:root:Epoch[2] Batch [400] Speed: 644.97 samples/sec accuracy=0.963200
INFO:root:Epoch[2] Batch [500] Speed: 654.75 samples/sec accuracy=0.967100
INFO:root:Epoch[2] Train-accuracy=0.969394
INFO:root:Epoch[2] Time cost=91.671
INFO:root:Epoch[2] Validation-accuracy=0.973100
INFO:root:Epoch[3] Batch [100] Speed: 660.64 samples/sec accuracy=0.970990
INFO:root:Epoch[3] Batch [200] Speed: 669.49 samples/sec accuracy=0.974400
INFO:root:Epoch[3] Batch [300] Speed: 650.88 samples/sec accuracy=0.973900
INFO:root:Epoch[3] Batch [400] Speed: 665.29 samples/sec accuracy=0.976800
INFO:root:Epoch[3] Batch [500] Speed: 664.31 samples/sec accuracy=0.976000
INFO:root:Epoch[3] Train-accuracy=0.978384
INFO:root:Epoch[3] Time cost=90.576
INFO:root:Epoch[3] Validation-accuracy=0.981600
INFO:root:Epoch[4] Batch [100] Speed: 657.94 samples/sec accuracy=0.978416
INFO:root:Epoch[4] Batch [200] Speed: 651.82 samples/sec accuracy=0.980100
INFO:root:Epoch[4] Batch [300] Speed: 653.96 samples/sec accuracy=0.982100
INFO:root:Epoch[4] Batch [400] Speed: 647.17 samples/sec accuracy=0.982400
INFO:root:Epoch[4] Batch [500] Speed: 656.77 samples/sec accuracy=0.981900
INFO:root:Epoch[4] Train-accuracy=0.984646
INFO:root:Epoch[4] Time cost=91.804
INFO:root:Epoch[4] Validation-accuracy=0.983400
INFO:root:Epoch[5] Batch [100] Speed: 649.50 samples/sec accuracy=0.983069
INFO:root:Epoch[5] Batch [200] Speed: 649.20 samples/sec accuracy=0.984600
INFO:root:Epoch[5] Batch [300] Speed: 647.68 samples/sec accuracy=0.985200
INFO:root:Epoch[5] Batch [400] Speed: 658.71 samples/sec accuracy=0.985900
INFO:root:Epoch[5] Batch [500] Speed: 646.41 samples/sec accuracy=0.984900
INFO:root:Epoch[5] Train-accuracy=0.987071
INFO:root:Epoch[5] Time cost=92.219
INFO:root:Epoch[5] Validation-accuracy=0.985100
INFO:root:Epoch[6] Batch [100] Speed: 645.74 samples/sec accuracy=0.985842
INFO:root:Epoch[6] Batch [200] Speed: 653.40 samples/sec accuracy=0.987800
INFO:root:Epoch[6] Batch [300] Speed: 646.12 samples/sec accuracy=0.987800
INFO:root:Epoch[6] Batch [400] Speed: 641.82 samples/sec accuracy=0.988100
INFO:root:Epoch[6] Batch [500] Speed: 643.05 samples/sec accuracy=0.986900
INFO:root:Epoch[6] Train-accuracy=0.989192
INFO:root:Epoch[6] Time cost=96.044
INFO:root:Epoch[6] Validation-accuracy=0.986100
INFO:root:Epoch[7] Batch [100] Speed: 653.00 samples/sec accuracy=0.987327
INFO:root:Epoch[7] Batch [200] Speed: 650.61 samples/sec accuracy=0.988800
INFO:root:Epoch[7] Batch [300] Speed: 649.02 samples/sec accuracy=0.989100
INFO:root:Epoch[7] Batch [400] Speed: 644.93 samples/sec accuracy=0.990000
INFO:root:Epoch[7] Batch [500] Speed: 554.87 samples/sec accuracy=0.988700
INFO:root:Epoch[7] Train-accuracy=0.990202
INFO:root:Epoch[7] Time cost=94.743
INFO:root:Epoch[7] Validation-accuracy=0.987600
INFO:root:Epoch[8] Batch [100] Speed: 649.92 samples/sec accuracy=0.988812
INFO:root:Epoch[8] Batch [200] Speed: 654.07 samples/sec accuracy=0.990800
INFO:root:Epoch[8] Batch [300] Speed: 656.73 samples/sec accuracy=0.990700
INFO:root:Epoch[8] Batch [400] Speed: 653.70 samples/sec accuracy=0.990900
INFO:root:Epoch[8] Batch [500] Speed: 631.36 samples/sec accuracy=0.990200
INFO:root:Epoch[8] Train-accuracy=0.991616
INFO:root:Epoch[8] Time cost=92.349
INFO:root:Epoch[8] Validation-accuracy=0.988500
INFO:root:Epoch[9] Batch [100] Speed: 647.88 samples/sec accuracy=0.990792
INFO:root:Epoch[9] Batch [200] Speed: 635.89 samples/sec accuracy=0.991900
INFO:root:Epoch[9] Batch [300] Speed: 637.18 samples/sec accuracy=0.991700
INFO:root:Epoch[9] Batch [400] Speed: 640.23 samples/sec accuracy=0.992300
INFO:root:Epoch[9] Batch [500] Speed: 640.93 samples/sec accuracy=0.991900
INFO:root:Epoch[9] Train-accuracy=0.992828
INFO:root:Epoch[9] Time cost=93.533
INFO:root:Epoch[9] Validation-accuracy=0.988700