1. 程式人生 > >TensorFlow實現多層LSTM識別MNIST手寫字,多層LSTM下state和output的關係

TensorFlow實現多層LSTM識別MNIST手寫字,多層LSTM下state和output的關係

其他內容

輸入格式:batch_size*784改成batch_size*28*28,28個序列,內容是一行的28個灰度數值。

讓神經網路逐行掃描一個手寫字型圖案,總結各行特徵,通過時間序列串聯起來,最終得出結論。

網路定義:單獨定義一個獲取單元的函式,便於在MultiRNNCell中呼叫,建立多層LSTM網路

def get_a_cell(i):
    lstm_cell =rnn.BasicLSTMCell(num_units=HIDDEN_CELL, forget_bias = 1.0, state_is_tuple = True, name = 'layer_%s'%i)
    print(type(lstm_cell))
    dropout_wrapped = rnn.DropoutWrapper(cell = lstm_cell, input_keep_prob = 1.0, output_keep_prob = keep_prob)
    return dropout_wrapped

multi_lstm = rnn.MultiRNNCell(cells = [get_a_cell(i) for i in range(LSTM_LAYER)],
                              state_is_tuple=True)#tf.nn.rnn_cell.MultiRNNCell

多層RNN下state和單層RNN有所不同,多了些細節,每一層都是一個cell,每一個cell都有自己的state,每一層都對應一個LSTMStateTuple(本例是分類預測,所以只用到最後一層的輸出,但是不代表其他情況不需要使用中間層的狀態)。

cell之間是串聯的,-1是最後一層的state,等價於單層下的output,我這裡建了三層,所以-1和2相等:


outputs, state = tf.nn.dynamic_rnn(multi_lstm, inputs = tf_x_reshaped, initial_state = init_state, time_major = False)
print('state:',state)
print('state[0]:',state[0])#layer 0's LSTMStateTuple
print('state[1]:',state[1])#layer 1's LSTMStateTuple
print('state[2]:',state[2])#layer 2's LSTMStateTuple
print('state[-1]:',state[-1])#layer 2's LSTMStateTuple

state: (LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(32, 256) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(32, 256) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(32, 256) dtype=float32>))
state[0]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(32, 256) dtype=float32>)
state[1]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(32, 256) dtype=float32>)
state[2]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(32, 256) dtype=float32>)
state[-1]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(32, 256) dtype=float32>)

下邊是outputs和states的對比:outputs對應state_2,又因為這裡做的是型別預測,是Nvs1模型,且time_major是False,第0維是batch,要取時間序列的最後一個輸出,用[:,-1,:],可以看到,是全相等的。


outputs, state = tf.nn.dynamic_rnn(multi_lstm, inputs = tf_x_reshaped, initial_state = init_state, time_major = False)
h_state_0 = state[0][1]
h_state_1 = state[1][1]
h_state = state[-1][1]
h_state_2 = h_state



        _, loss_,outputs_, state_, h_state_0_, h_state_1_, h_state_2_ = \
            sess.run([train_op, cross_entropy,outputs, state, h_state_0, h_state_1, h_state_2], {tf_x:x, tf_y:y, keep_prob:1.0})


        print('h_state_2_ == outputs_[:,-1,:]:', h_state_2_ == outputs_[:,-1,:])


h_state_2_ == outputs_[:,-1,:]: [[ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 ...
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]]

最後處理一下輸出:LSTM的介面為了使用方便,輸入輸出是等維度的,不可設定,隱藏單元這裡設定的256,需要做一個轉換,轉換為10維輸出,最終對手寫數字進行分類預測。

#prediction and loss
W = tf.Variable(initial_value = tf.truncated_normal([HIDDEN_CELL, CLASS_NUM], stddev = 0.1 ), dtype = tf.float32)
print(W)
b = tf.Variable(initial_value = tf.constant(0.1, shape = [CLASS_NUM]), dtype = tf.float32)
predictions = tf.nn.softmax(tf.matmul(h_state, W) + b)
#sum   -ylogy^
cross_entropy = -tf.reduce_sum(tf_y * tf.log(predictions))

完整程式碼:

相關推薦

TensorFlow實現LSTM識別MNIST寫字LSTMstateoutput關係

其他內容 輸入格式:batch_size*784改成batch_size*28*28,28個序列,內容是一行的28個灰度數值。 讓神經網路逐行掃描一個手寫字型圖案,總結各行特徵,通過時間序列串聯起來,最終得出結論。 網路定義:單獨定義一個獲取單元的函式,便於在M

tensorflow實現感知機進行寫字識別

logits=multilayer_perceptron(X) #使用交叉熵損失 loss_op=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=Y))

Tensorflow #2 深度學習-RNN LSTMMNIST識別Demo

Tensorflow #1 祖傳例子 MNIST 手寫識別 Tensorflow自帶的那個MNIST任務其實挺好用的,之前使用最簡單的方法去做,記得正確率應該是92%附近? 其實那個例子是用來熟悉Tensorflow的,算是一個對Tensorflow的熟悉

深度學習-tensorflow學習筆記(2)-MNIST寫字識別

image utf-8 詳情 識別 標簽 ins AI tor 第一個           深度學習-tensorflow學習筆記(2)-MNIST手寫字體識別   這是tf入門的第一個例子。minst應該是內置的數據集。   前置知識在學習筆記(1)裏面講過了   這裏直

TensorFlow學習筆記(1)—— MNIST識別

1、初步學習 資料處理 xs:60000張圖片,28*28大小,將所有畫素點按一列排列,資料集變為了[60000, 784]的二維矩陣。 ys:60000張圖片,每個圖片有一個標籤標識圖片中數字,採用one-hot向量,資料集變為[60000, 10]的二維矩陣。 softm

mnist寫字識別tensorflow與tflearn對比

一、mnist機器學習入門 MNIST是一個入門級的計算機視覺資料集,它包含各種手寫數字圖片: 它也包含每一張圖片對應的標籤,告訴我們這個是數字幾。比如,上面這四張圖片的標籤分別是5,0,4,1。 在此教程中,我們將訓練一個機器學習模型用於預測圖片裡面的數字。我們的

Tensorflow案例5:CNN演算法-Mnist寫數字識別

學習目標 目標 應用tf.nn.conv2d實現卷積計算 應用tf.nn.relu實現啟用函式計算 應用tf.nn.max_pool實現池化層的計算 應用卷積神經網路實現影象分類識別 應用

Tensorflow | MNIST寫字識別

這次對最近學習tensorflow的總結,以理解MNIST手寫字識別案例為例來說明 0、資料解釋 資料為圖片,每個圖片是28畫素*28畫素,帶有標籤,類似於X和Y,X為28畫素*28畫素的資料,Y為該圖片的真實數字,即標籤。 1、資料的處理

深度學習筆記——TensorFlow學習筆記(三)使用TensorFlow實現的神經網路進行MNIST手寫體數字識別

本文是TensorFlow學習的第三部分,參考的是《TensorFlow實戰Google深度學習框架》一書,這部分講述的是使用TensorFlow實現的神經網路進行MNIST手寫體數字識別一個例項。 這個例項將第二部分講述的啟用函式、損失函式、優化演算法、正則化等都運用上了

pytorch 利用lstmmnist寫數字識別分類

程式碼如下,U我認為對於新手來說最重要的是學會rnn讀取資料的格式。 # -*- coding: utf-8 -*- """ Created on Tue Oct 9 08:53:25 2018 @author: www """ import sys sys.path

Keras中將LSTM用於mnist寫數字識別

import keras from keras.layers import LSTM from keras.layers import Dense, Activation from keras.datasets import mnist from keras.models

運用tensorflow全連線神經網路進行MNIST寫數字影象識別

本文記錄tensorflow搭建簡單神經網路,並進行模組化處理,目的在於總結並提取簡單神經網路搭建的基本思想和方法,提煉核心結構和元素,從而能夠移植到日後深入學習中去。 1 模組提煉 1.1 template_forward.py

人工智能 tensorflow框架-->MNIST寫字符數據集 06

推廣 x文件 數據集 2.4 mage esx cnblogs -i 向量空間 1.下載MNIST數據集: 2.1數據集分成兩部分:60000行的訓練集 trainxxx (包含手寫數字的圖片imagexxx 和 手寫數字對應的標簽labelxxx)

matlab練習程序(神經網絡識別mnist寫數據集)

sum else ref rate 標準 個數 權重矩陣 ros learn 記得上次練習了神經網絡分類,不過當時應該有些地方寫的還是不對。 這次用神經網絡識別mnist手寫數據集,主要參考了深度學習工具包的一些代碼。 mnist數據集訓練數據一共有28*28*6000

matlab練習程式(神經網路識別mnist寫資料集)

記得上次練習了神經網路分類,不過當時應該有些地方寫的還是不對。 這次用神經網路識別mnist手寫資料集,主要參考了深度學習工具包的一些程式碼。 mnist資料集訓練資料一共有28*28*60000個畫素,標籤有60000個。 測試資料一共有28*28*10000個,標籤10000個。 這裡神經網路輸入

LSTMMNIST寫資料集上做分類(程式碼中尺寸變換細節)

RNN和LSTM學了有一段時間了,主要都是看部落格瞭解原理,最近在研究SLSTM,在對SLSTM進行實現的時候遇到了困難,想說先比較一下二者的理論實現,這才發現自己對於LSTM內部的輸入輸出格式、輸出細節等等都不是非常清楚,藉此機會梳理一下,供後續研究使用。 下面程式碼來自

機器學習筆記(十二):TensorFlow實現四(影象識別與卷積神經網路)

1 - 卷積神經網路常用結構 1.1 - 卷積層 我們先來介紹卷積層的結構以及其前向傳播的演算法。 一個卷積層模組,包含以下幾個子模組: 使用0擴充邊界(padding) 卷積視窗過濾器(filter) 前向卷積 反向卷積(可選) 1.1.2 - 邊界填充

Keras 入門課1 -- 用MLP識別mnist寫字元

mlp就是multilayer perceptron,多層感知機。資料集用的是經典的mnist,數字分類問題。 首先匯入keras的各種模組 keras.datasets 裡面包含了多種常用資料集,如mnist,cifar10等等,可以實現自動下載和解析等等。

tensorflow實現人臉檢測及識別(簡單版)

本教程主要是對人臉檢測及識別python實現系列 及碉堡了!程式設計師用深度學習寫了個老闆探測器(付原始碼) 的實現。主要實現的功能是用網路攝像頭自動識別在工位通道走過的人臉,如果確認是老闆的話,就用一張圖片覆蓋到整個螢幕上。雖然原教程已經寫的很好,但是我們在實現的時候仍然踩

MATLAB自動識別MNIST寫數字資料庫

1.MNIST手寫數字資料庫 資料庫由Google實驗室的Corinna Cortes和紐約大學柯朗研究所的Yann LeCun建有一個手寫數字資料庫,訓練庫有60,000張手寫數字影象,測試庫有10,000張。下載地址: 2 程式碼 This release i