1. 程式人生 > >TensorFlow入門(二)簡單前饋網路實現 mnist 分類

TensorFlow入門(二)簡單前饋網路實現 mnist 分類

歡迎轉載,但請務必註明原文出處及作者資訊。

兩層FC層做分類:MNIST

在本教程中,我們來實現一個非常簡單的兩層全連線網路來完成MNIST資料的分類問題。
輸入[-1,28*28], FC1 有 1024 個neurons, FC2 有 10 個neurons。這麼簡單的一個全連線網路,結果測試準確率達到了 0.98。還是非常棒的!!!

import numpy as np
import tensorflow as tf

# 設定按需使用GPU
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=config)

1. 匯入資料

# 用tensorflow 匯入資料
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
print 'training data shape ', mnist.train.images.shape
print 'training label shape ', mnist.train.labels.shape
training data shape  (55000, 784)
training label shape  (55000, 10)

2. 構建網路

# 權值初始化
def weight_variable(shape):
    # 用正態分佈來初始化權值
    initial = tf.truncated_normal(shape, stddev=0.1)
    return
tf.Variable(initial) def bias_variable(shape): # 本例中用relu啟用函式,所以用一個很小的正偏置較好 initial = tf.constant(0.1, shape=shape) return tf.Variable(initial) # input_layer X_ = tf.placeholder(tf.float32, [None, 784]) y_ = tf.placeholder(tf.float32, [None, 10]) # FC1 W_fc1 = weight_variable([784, 1024]) b_fc1 = bias_variable([1024]) h_fc1 = tf.nn.relu(tf.matmul(X_, W_fc1) + b_fc1) # FC2 W_fc2 = weight_variable([1024, 10]) b_fc2 = bias_variable([10]) y_pre = tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2)

3. 訓練和評估

# 1.損失函式:cross_entropy
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_pre))
# 2.優化函式:AdamOptimizer, 優化速度要比 GradientOptimizer 快很多
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

# 3.預測結果評估
# 預測值中最大值(1)即分類結果,是否等於原始標籤中的(1)的位置。argmax()取最大值所在的下標
correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.arg_max(y_, 1))  
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# 開始執行
sess.run(tf.global_variables_initializer())
# 這大概迭代了不到 10 個 epoch, 訓練準確率已經達到了0.98
for i in range(5000):
    X_batch, y_batch = mnist.train.next_batch(batch_size=100)
    train_step.run(feed_dict={X_: X_batch, y_: y_batch})
    if (i+1) % 200 == 0:
        train_accuracy = accuracy.eval(feed_dict={X_: mnist.train.images, y_: mnist.train.labels})
        print "step %d, training acc %g" % (i+1, train_accuracy)
    if (i+1) % 1000 == 0:
        test_accuracy = accuracy.eval(feed_dict={X_: mnist.test.images, y_: mnist.test.labels})
        print "= " * 10, "step %d, testing acc %g" % (i+1, test_accuracy)
step 200, training acc 0.937364
step 400, training acc 0.965818
step 600, training acc 0.973364
step 800, training acc 0.977709
step 1000, training acc 0.981528
= = = = = = = = = =  step 1000, testing acc 0.9688
step 1200, training acc 0.988437
step 1400, training acc 0.988728
step 1600, training acc 0.987491
step 1800, training acc 0.993873
step 2000, training acc 0.992527
= = = = = = = = = =  step 2000, testing acc 0.9789
step 2200, training acc 0.995309
step 2400, training acc 0.995455
step 2600, training acc 0.9952
step 2800, training acc 0.996073
step 3000, training acc 0.9964
= = = = = = = = = =  step 3000, testing acc 0.9778
step 3200, training acc 0.996709
step 3400, training acc 0.998109
step 3600, training acc 0.997455
step 3800, training acc 0.995055
step 4000, training acc 0.997291
= = = = = = = = = =  step 4000, testing acc 0.9808
step 4200, training acc 0.997746
step 4400, training acc 0.996073
step 4600, training acc 0.998564
step 4800, training acc 0.997946
step 5000, training acc 0.998673
= = = = = = = = = =  step 5000, testing acc 0.98

相關推薦

TensorFlow入門簡單網路實現 mnist 分類

歡迎轉載,但請務必註明原文出處及作者資訊。 兩層FC層做分類:MNIST 在本教程中,我們來實現一個非常簡單的兩層全連線網路來完成MNIST資料的分類問題。 輸入[-1,28*28],

TensorFlow入門 簡單網路實現 mnist 分類

import tensorflow as tf # 設定按需使用GPU config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.InteractiveSession(config=

spring-data-jpa快速入門——簡單查詢

ref spa data mail domain event cif open 寫實 一、方法名解析   1.引言     回顧HelloWorld項目中的dao接口 public interface GirlRepository extends JpaRepos

tensorflow入門 變數的定義、初始化、值的檢視

1、常量 constant是TensorFlow的常量節點,通過constant方法建立,其是計算圖(Computational Graph)中的起始節點,是傳入資料; import tensorflow as tf sess = tf.Interact

AI聖經-深度學習-讀書筆記-深度網路

深度前饋網路(DFN) 0 簡介 (1)DFN:深度前饋網路,或前饋神經網路(FFN),或多層感知機(MLP) (2)目標 近似某個函式 f∗f∗。例如,定義一個對映y=f(x;θ)y=f(x;θ),並且學習θθ的值,使它能夠得到最佳的函式近似。

動態規劃入門DP 基本思想 具體實現 經典題目 POJ1088

(一) POJ1088,動態規劃的入門級題目。嘿嘿,連題目描述都是難得一見的中文。 題目分析: 求最長的滑雪路徑,關鍵是確定起點,即從哪開始滑。 不妨設以( i, j )為起點,現在求滑行的最長路徑。 首先,( i, j )能滑向的無非就是它四周比它低的點。到底滑向哪個點?

《deep learning》學習筆記6——深度網路

6.1 例項:學習 XOR 通過學習一個表示來解決 XOR 問題。圖上的粗體數字標明瞭學得的函式必須在每個點輸出的值。(左) 直接應用於原始輸入的線性模型不能實現 XOR 函式。當 x 1 = 0 時,模型的輸出必須隨著 x 2 的增大而增大。當 x

Asp.Net Core WebAPI入門整理簡單示例

序列 open exc tor pda template ssa net found 一、Core WebAPI中的序列化 使用的是Newtonsoft.Json,自定義全局配置處理: // This method gets called by the runtime.

Spring Boot簡單入門

上篇演示了Spring Boot最簡單的一個例項,這篇將整合mysql和mybatis實現簡單的登入功能。 以前我們寫java程式碼,都要自己手動一個一個去建實體類,寫方法,這種瑣碎的事情不但無聊且浪費時間,這裡介紹一個好用的工具,Mybatis-Generator。它可以自動生成Dao、Mod

深度學習入門——TensorFlow介紹

TensorFlow     1.使用圖 (graph) 來表示計算任務.     2.在被稱之為 會話 (Session) 的上下文 (context) 中執行圖.     3.使用 tensor 表示資料.     4通過 變數 (Variable) 維護狀態.

簡單搜尋入門:二分答案 HDU 5248

二分練習的第二部分——二分答案的查詢 Description 給定序列A={A1,A2,…,An}, 要求改變序列A中的某些元素,形成一個嚴格單調的序列B(嚴格單調的定義為:Bi< Bi+1, 1≤i< N)。 我們定義從序列A到

TensorFlow初學者入門——MNIST機器學習入門

本人學習TensorFlow中的一些學習筆記和感悟,僅供學習參考,有疑問的地方可以一起交流討論,持續更新中。 本文學習地址為:TensorFlow官方文件,在此基礎上加入了自己的學習筆記和理解。 文章是建立在有一定的深度學習基礎之上的,建議有一定理論基礎之後再同步學習。 1.準備MNIS

效能測試入門:做個最簡單的效能測試

之前在《效能測試中的各項指標告訴我們什麼》簡單介紹了一些基本的效能指標的含義,明確了我們效能測試的目標是在保證請求成功率及不超過目標請求時間的情況下,找出我們系統的最大併發量。在這篇文章中我們做些實踐,以程式設計師小張的視角來做一次效能測試。 做個最簡單的

Vue簡單入門

click 處理 tex sage .com 數據 工作 spa -c 根據上一節搭建的hello-world工程(包含Router),用Webstorm打開,我們先運行一下工程。 界面如下 .. 我將在About裏面介紹一下Vue的相關內容。 打開Abo

swing入門教程 簡單的swing小部件

—— 就像所有的“x 入門”教程一樣,本教程也包含必不可少的 HelloWorld 演示。但這個示例不僅對觀察 Swing 應用程式如何工作有用,還對確保設定正確很有用。一旦使這個簡單的應用程式能

數字IC低功耗設計入門——功耗的分析

layout 變化 監視 merge obj source divide 傳播 總結   前面學習了進行低功耗的目的個功耗的構成,今天就來分享一下功耗的分析。由於是面向數字IC前端設計的學習,所以這裏的功耗分析是基於DC中的power compiler工具;更精確的功耗分析

Linux入門

man linux終端 linux發行版本 linux文件系統初步 google高級用法 Linux常用的基礎命令1.發行版本2.CISC、RISC3.編譯和反編譯(GPL、LGPL、BSD)4.程序包管理5.文件系統初步終端設備虛擬終端圖形終端串行終端偽終端Linux的哲學思想6.開源協

Docker入門

docker安裝 docker基礎命令 一、Docker相關概念1.Docker: namespace,cgroup: 解決方案: lxc,openvz lxc:linux containers docker最初就是lxc的封裝版本。 docker engine/docker server:輸

vue-cli入門——項目結構

常用 作用 寫到 www. 簡單的 端口 server 標簽 emp 前言 在上一篇項目搭建文章中,我們已經下載安裝了node環境以及vue-cli,並且已經成功構建了一個vue-cli項目,那麽接下來,我們來梳理一下vue-cli項目的結構。 總體框架 一個vue-c

log4j2使用入門——與不同日誌框架的適配

一個 slf4 core log4j 說明 不同 activemq 進行 -a 在上方中已經指出log4j2可以與不同的日誌框架進行適配,這裏舉一些實際應用進行說明: 1.比如我們在項目中使用了log4j2作為日誌器,使用了log4j-api2.6.2.jar和log4j