1. 程式人生 > >用RNN做MNIST分類

用RNN做MNIST分類

1.前言

RNN常用作NLP中,像圖片生成文字、自動生成古詩詞等。這篇文章用RNN做MNIST手寫數字識別,分類效果雖然沒有CNN效果好,但準確率也能夠達到96%。

2.環境

Mac os系統,python:3.5,Keras

3.程式碼實現

import numpy as np
np.random.seed(1337)  

from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import SimpleRNN, Activation, Dense
from keras.optimizers import Adam

TIME_STEPS = 28     
INPUT_SIZE = 28    
BATCH_SIZE = 50
BATCH_INDEX = 0
OUTPUT_SIZE = 10
CELL_SIZE = 50
LR = 0.001

(X_train, y_train), (X_test, y_test) = mnist.load_data()

# data pre-processing
X_train = X_train.reshape(-1, 28, 28) / 255.      # normalize
X_test = X_test.reshape(-1, 28, 28) / 255.        # normalize
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10)

# build RNN model
model = Sequential()

# RNN cell
model.add(SimpleRNN(
    batch_input_shape=(None, TIME_STEPS, INPUT_SIZE),       
    output_dim=CELL_SIZE,
    unroll=True,
))

# output layer
model.add(Dense(OUTPUT_SIZE))
model.add(Activation('softmax'))

# optimizer
adam = Adam(LR)
model.compile(optimizer=adam,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# training
for step in range(40001):
    X_batch = X_train[BATCH_INDEX: BATCH_INDEX+BATCH_SIZE, :, :]
    Y_batch = y_train[BATCH_INDEX: BATCH_INDEX+BATCH_SIZE, :]
    cost = model.train_on_batch(X_batch, Y_batch)
    BATCH_INDEX += BATCH_SIZE
    BATCH_INDEX = 0 if BATCH_INDEX >= X_train.shape[0] else BATCH_INDEX

    if step % 500 == 0:
        cost, accuracy = model.evaluate(X_test, y_test, batch_size=y_test.shape[0], verbose=False)
        print('test cost: ', cost, 'test accuracy: ', accuracy)

4.結果


相關推薦

RNNMNIST分類

1.前言 RNN常用作NLP中,像圖片生成文字、自動生成古詩詞等。這篇文章用RNN做MNIST手寫數字識別,分類效果雖然沒有CNN效果好,但準確率也能夠達到96%。 2.環境 Mac os系統,pyt

完整案例:caffe框架Alexnet分類的全部流程

一.資料的準備與預處理 資料的準備是非常重要的,我們現在準備做的是一個二分類任務,計劃選取男女圖片進行訓練與測試。 train:隨機選取300張男性圖片,300張女性圖片 val:隨機選取80張男

[深度學習框架] Keras上使用RNN進行mnist分類

import numpy as np np.random.seed(1337) # for reproducibility from keras.datasets import mnist from keras.utils import np_utils from ke

cnn行人分類

機器學習資料庫是關鍵,自己搜搜吧,規模太小訓練不出來,正樣本和負樣本。 訓練之前要處理訓練檔案,這個我在之前的python影象操作這篇博文裡寫過,並有完整程式碼。 也可以用我處理好的資料,稍後我會上傳 input_data.py """Functions for do

mysql遞迴查詢,mysql中從子類ID查詢所有父類(無限分類經常到)

由於mysql 不支援類似 oracle with ...connect的 遞迴查詢語法之前一直以為類似的查詢要麼用儲存過程要麼只能用程式寫遞迴查詢.現在發現原來一條sql語句也是可以搞定的先來看資料表的結構如下:id  name    parent_id&n

TensorFlow(五)——MNIST分類RNN

import input_data import tensorflow as tf import numpy as np mnist = input_data.read_data_sets('data/', one_hot=True) #設定訓練超引數 lr = 0.00

深度學習入門專案:keras構建CNN或LSTM對minist資料集簡單分類任務

深度學習入門專案:用keras構建CNN或LSTM或RNN對Minist資料集做簡單分類任務 參考keras中文文件 ——keras: 是一個高階神經網路庫,用 Python 語言寫成,可以執行在 TensorFlow 或者 Theano 之上(即以此為後端)。

gensimLDA實踐之文字分類

之前看LDA,一直沒搞懂到底作用是什麼,公式推導了一大堆,dirichlet分佈求了一堆倒數,卻沒有真正理解精髓在哪裡。 最近手上遇到了一個文字分類的問題,採用普通的VSM模型的時候,執行的太慢,後來查詢改進策略的時候,想起了LDA,因此把LDA重新拉回我的視

[TensorFlow深度學習入門]實戰十·RNN(LSTM)時間序列預測(曲線擬合)

[TensorFlow深度學習入門]實戰十·用RNN(LSTM)做時間序列預測(曲線擬合) %matplotlib inline import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" import numpy as np import

[TensorFlow深度學習入門]實戰五·RNN(LSTM)手寫數字識別準確率98%+

參考部落格地址,修復了一個小Bug,收斂速度和準確率都略微提升。使用此模型在Kaggle比賽準確率98%+ import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" import time import tensorflow as tf im

Sklearn判別分析(分類)

來自:http://cloga.info/python/2014/02/07/classify_use_Sklearn/#wat_e_12612920-6fe4-464e-a2b0-3b1f13c1a4f6_zss_ 載入資料集 這裡我使用pandas來載入資料集,資料集採用kaggle的titanic

分類和擴充套件有什麼區別?可以分別什麼?分類有哪些侷限性?分類的結構體裡面有哪些成員?

1、分類中原則上只能增加方法(能新增屬性的的原因只是通過runtime的objc_setAssociatedObject和objc_getAssociatedObject方法新增setter/getter方法); 2、擴充套件不僅可以增加方法,還可以增加例項變數(或者屬性),只是該例項變數預設是@p

Windows下Matlab載入caffemodel影象分類

1.編譯caffe的matlab介面 用到了happynear提供的caffe-windows-master,編譯caffe和matlab介面的過程看這裡。編譯好之後,caffe-windows-master\matlab\+caffe\private內的檔案如下: 如果

最新NLP庫Flair文字分類

介紹 文字分類是一種監督機器學習方法,用於將句子或文字文件歸類為一個或多個已定義好的類別。它是一個被廣泛應用的自然語言處理方法,在垃圾郵件過濾、情感分析、新聞稿件分類以及與許多其它業務相關的問題中發揮著重要作用。 目前絕大多數最先進的方法都依賴於一種被稱為文字嵌入的技術。它將文字轉換成高維空間

pytorch 利用lstmmnist手寫數字識別分類

程式碼如下,U我認為對於新手來說最重要的是學會rnn讀取資料的格式。 # -*- coding: utf-8 -*- """ Created on Tue Oct 9 08:53:25 2018 @author: www """ import sys sys.path

Windows下VS2013載入caffemodel影象分類

結果顯示在左上角,有英文和中文兩種標籤可選,如果顯示中文,需要使用Freetype庫,請自行百度。 #include <caffe/caffe.hpp> #include <opencv2/core/core.hpp> #include <opencv2/highgui/hig

單向RNN和雙向RNNmnist資料集上的分類實驗

RNN用於影象分類思路很奇特,不明覺厲,具體可以參考相關論文,rnn和birnn的實驗: #!/usr/bin/env python # -*- coding: utf-8 -*- # created by fhqplzj on 2017/06/19 下午10:28 fro

iptables軟路由實現共享上網

iptables 軟路由 共享上網 我們平時使用的大多數家用路由器都是通過NAT(Network Address Translation,網絡地址轉換)功能實現共享上網的,iptables是linux內核裏整合的一個ip信息包過濾系統,使用iptables配置nat就可以實現和家用路由器一樣的上網效果。

R語言nls非線性回歸以及函數模型的參數估計

nes 線性 -s legend 最小值 fun des and start 非線性回歸是在對變量的非線性關系有一定認識前提下,對非線性函數的參數進行最優化的過程,最優化後的參數會使得模型的RSS(殘差平方和)達到最小。在R語言中最為常用的非線性回歸建模函數是nls,下面以

郵箱網頁

text 軟件開發工具 kdt itl dia section ear 軟件 界面 https://zhidao.baidu.com/question/438466707.html qq郵箱郵件怎麽做成網頁發送??郵件當成網頁發送是不是很酷呢,看看我的經驗