MXNet | 手寫字MNIST識別比賽
阿新 • • 發佈:2019-01-26
MNIST手寫字圖片資料集由Yann LeCun建立,每條資料表示28*28畫素的圖片。它已經是用於衡量分類器在簡單圖片作為輸入的標準資料集。神經網路是對於圖片分類任務來說是強大的模型。這是一個在kaggle長期舉辦的比賽資料集。
讀取資料集,這裡用readr中的函式read_csv,讀取速度快高效
setwd("F:\\迅雷下載\\mnist")
require(mxnet)
library(readr)
train <- read_csv('train.csv')
test <- read_csv('test.csv')
資料集:訓練集和測試集
> train <- data .matrix(train)
> test <- data.matrix(test)
> train.x <- train[,-1]
> train.y <- train[,1]
> train <- data.matrix(train)
> test <- data.matrix(test)
> train.x <- train[,-1]
> train.y <- train[,1]
資料放縮到[0,1]
> train.x <- t(train.x/255)
> test <- t(test/255)
標籤
> table(train.y)
train.y
0 1 2 3 4 5 6 7 8 9
4132 4684 4177 4351 4072 3795 4137 4401 4063 4188
資料集還是比較平衡,不同之間的差異不大
構建網路
#定義
> data <- mx.symbol.Variable("data")
#第一層,全連線,隱藏節點128個
> fc1 <- mx.symbol.FullyConnected(data, name="fc1", num_hidden=128)
#啟用函式為relu
> act1 <- mx.symbol.Activation(fc1, name="relu1", act_type="relu")
#第二層,隱藏節點為64個
> fc2 <- mx.symbol.FullyConnected(act1, name="fc2", num_hidden=64)
#啟用函式為relu
> act2 <- mx.symbol.Activation(fc2, name="relu2", act_type="relu")
#第三層,隱藏節點為10個
> fc3 <- mx.symbol.FullyConnected(act2, name="fc3", num_hidden=10)
##啟用函式為sm,即softmax
> softmax <- mx.symbol.SoftmaxOutput(fc3, name="sm")
訓練,採用cpu的方式
#cpu
>devices <- mx.cpu()
#隨機種子
>mx.set.seed(0)
#模型
>model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y,
ctx=devices, num.round=10, array.batch.size=100,
learning.rate=0.07, momentum=0.9, eval.metric=mx.metric.accuracy,
initializer=mx.init.uniform(0.07),
epoch.end.callback=mx.callback.log.train.metric(100))
Start training with 1 devices
[1] Train-accuracy=0.859832935560859
[2] Train-accuracy=0.957666666666668
[3] Train-accuracy=0.971023809523813
[4] Train-accuracy=0.977714285714289
[5] Train-accuracy=0.981571428571432
[6] Train-accuracy=0.986309523809527
[7] Train-accuracy=0.988952380952383
[8] Train-accuracy=0.990880952380956
[9] Train-accuracy=0.992142857142861
[10] Train-accuracy=0.991095238095241
訓練的精度為99.10%
預測
> preds <- predict(model, test)
> dim(preds)
[1] 10 28000
> pred.label <- max.col(t(preds)) - 1
預測後的類別
> table(pred.label)
pred.label
0 1 2 3 4 5 6 7 8 9
2816 3216 2753 2791 2709 2544 2762 2836 2780 2793
得到提交的資料集ID和label
submission <- data.frame(ImageId=1:ncol(test), Label=pred.label)
write.csv(submission, file='submission.csv', row.names=FALSE, quote=FALSE)
submission.csv檔案在你的工作目錄下,然後去kaggle提交下。
結果顯示
下面給出完整的程式碼:
setwd("F:\\迅雷下載\\mnist")
require(mxnet)
library(readr)
train <- read_csv('train.csv')
test <- read_csv('test.csv')
train <- data.matrix(train)
test <- data.matrix(test)
train.x <- train[,-1]
train.y <- train[,1]
# 資料放縮到[0,1]
train.x <- t(train.x/255)
test <- t(test/255)
table(train.y)
#構建網路
data <- mx.symbol.Variable("data")
fc1 <- mx.symbol.FullyConnected(data, name="fc1", num_hidden=128)
act1 <- mx.symbol.Activation(fc1, name="relu1", act_type="relu")
fc2 <- mx.symbol.FullyConnected(act1, name="fc2", num_hidden=64)
act2 <- mx.symbol.Activation(fc2, name="relu2", act_type="relu")
fc3 <- mx.symbol.FullyConnected(act2, name="fc3", num_hidden=10)
softmax <- mx.symbol.SoftmaxOutput(fc3, name="sm")
########訓練
##cpu
devices <- mx.cpu()
mx.set.seed(0)
model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y,
ctx=devices, num.round=10, array.batch.size=100,
learning.rate=0.07, momentum=0.9, eval.metric=mx.metric.accuracy,
initializer=mx.init.uniform(0.07),
epoch.end.callback=mx.callback.log.train.metric(100))
#預測
preds <- predict(model, test)
dim(preds)
pred.label <- max.col(t(preds)) - 1
table(pred.label)
submission <- data.frame(ImageId=1:ncol(test), Label=pred.label)
write.csv(submission, file='submission.csv', row.names=FALSE, quote=FALSE)