
TensorFlow Learning (3): Walking through the mnist_experts example, training and saving the model, and visualizing with TensorBoard

Preface

The official MNIST tutorial already explains this example well, but to really understand what the network structure is doing you still need a solid grounding in the basics of deep learning. Michael Nielsen's book Neural Networks and Deep Learning is well worth a careful read; the Chinese translation can be downloaded here: 中文版pdf

This post walks through the mnist_experts example and rewrites it using two common TensorFlow facilities (TensorBoard visualization and saving the model with Saver) to make the example more complete.

Before you start, you should be familiar with the MNIST dataset; see this blog post for details: mnist解讀

Main content

I. A first look at the example's network structure and the rewritten source code

[Figure: network structure]

# encoding=utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

def weight_variable(shape):
    # weights initialised with small Gaussian noise
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    # small positive bias
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

myGraph = tf.Graph()
with myGraph.as_default():
    with tf.name_scope('inputsAndLabels'):
        x_raw = tf.placeholder(tf.float32, shape=[None, 784])
        y = tf.placeholder(tf.float32, shape=[None, 10])

    with tf.name_scope('hidden1'):
        # first conv + max-pool layer: 28x28x1 -> 14x14x32
        x = tf.reshape(x_raw, shape=[-1, 28, 28, 1])
        W_conv1 = weight_variable([5, 5, 1, 32])
        b_conv1 = bias_variable([32])
        l_conv1 = tf.nn.relu(tf.nn.conv2d(x, W_conv1, strides=[1, 1, 1, 1], padding='SAME') + b_conv1)
        l_pool1 = tf.nn.max_pool(l_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

        tf.summary.image('x_input', x, max_outputs=10)
        tf.summary.histogram('W_con1', W_conv1)
        tf.summary.histogram('b_con1', b_conv1)

    with tf.name_scope('hidden2'):
        # second conv + max-pool layer: 14x14x32 -> 7x7x64
        W_conv2 = weight_variable([5, 5, 32, 64])
        b_conv2 = bias_variable([64])
        l_conv2 = tf.nn.relu(tf.nn.conv2d(l_pool1, W_conv2, strides=[1, 1, 1, 1], padding='SAME') + b_conv2)
        l_pool2 = tf.nn.max_pool(l_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

        tf.summary.histogram('W_con2', W_conv2)
        tf.summary.histogram('b_con2', b_conv2)

    with tf.name_scope('fc1'):
        # fully connected layer with dropout: 7*7*64 -> 1024
        W_fc1 = weight_variable([64*7*7, 1024])
        b_fc1 = bias_variable([1024])
        l_pool2_flat = tf.reshape(l_pool2, [-1, 64*7*7])
        l_fc1 = tf.nn.relu(tf.matmul(l_pool2_flat, W_fc1) + b_fc1)
        keep_prob = tf.placeholder(tf.float32)
        l_fc1_drop = tf.nn.dropout(l_fc1, keep_prob)

        tf.summary.histogram('W_fc1', W_fc1)
        tf.summary.histogram('b_fc1', b_fc1)

    with tf.name_scope('fc2'):
        # output layer: 1024 -> 10 logits
        W_fc2 = weight_variable([1024, 10])
        b_fc2 = bias_variable([10])
        y_conv = tf.matmul(l_fc1_drop, W_fc2) + b_fc2

        tf.summary.histogram('W_fc2', W_fc2)
        tf.summary.histogram('b_fc2', b_fc2)

    with tf.name_scope('train'):
        cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_conv, labels=y))
        train_step = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cross_entropy)
        correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        tf.summary.scalar('loss', cross_entropy)
        tf.summary.scalar('accuracy', accuracy)


with tf.Session(graph=myGraph) as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter('./mnistEven/', graph=sess.graph)

    for i in range(10001):
        batch = mnist.train.next_batch(50)
        sess.run(train_step, feed_dict={x_raw: batch[0], y: batch[1], keep_prob: 0.5})
        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={x_raw: batch[0], y: batch[1], keep_prob: 1.0})
            print('step %d training accuracy:%g' % (i, train_accuracy))

            summary = sess.run(merged, feed_dict={x_raw: batch[0], y: batch[1], keep_prob: 1.0})
            summary_writer.add_summary(summary, i)

    test_accuracy = accuracy.eval(feed_dict={x_raw: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
    print('test accuracy:%g' % test_accuracy)

    saver.save(sess, save_path='./model/mnistmodel', global_step=1)

II. Network structure
As the structure diagram above shows, the input image first passes through two conv-pooling layers, then two fully connected layers (with dropout after the first), and finally the softmax output. The role of each of these layers is explained in detail in the book recommended in the preface.
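To connect the diagram with the code, it helps to trace the tensor shapes layer by layer. The quick sketch below (using the variable names from the listing above) also shows why the first fully connected layer expects 64*7*7 inputs:

# Shape walk-through (run anywhere after the graph has been built):
#   x        : [batch, 28, 28, 1]   reshaped input image
#   l_pool1  : [batch, 14, 14, 32]  conv1 (SAME padding, stride 1) + 2x2 max-pool
#   l_pool2  : [batch, 7, 7, 64]    conv2 + 2x2 max-pool
#   l_fc1    : [batch, 1024]        fully connected layer (dropout on top)
#   y_conv   : [batch, 10]          logits; softmax is applied inside the loss
print(l_pool1.get_shape())   # (?, 14, 14, 32)
print(l_pool2.get_shape())   # (?, 7, 7, 64)
print(y_conv.get_shape())    # (?, 10)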

III. TensorBoard visualization
TensorFlow ships with a visualization tool called TensorBoard. It can draw the computation graph, and it can record how values change during training and plot them as curves. Using it revolves around the tf.summary.* APIs: while building the graph, attach a summary op to each value you want to track; once the graph is built, merge all the summaries and write them to an event file during training.

To keep the computation graph tidy, each layer is usually defined inside its own name scope, as shown below:

with tf.name_scope('hidden1'):
        x = tf.reshape(x_raw, shape=[-1,28,28,1])
        W_conv1 = weight_variable([5,5,1,32])
        b_conv1 = bias_variable([32])

With the W and b variables wrapped in a name scope, the rendered graph shows a box named "hidden1"; clicking the box expands it to reveal the variables inside.
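A quick way to confirm this grouping (a check of my own, not from the original example) is to print an op's name: everything created inside the scope gets the "hidden1/" prefix, and that prefix is exactly what TensorBoard uses to draw the collapsible box.

with tf.name_scope('hidden1'):
    W_conv1 = weight_variable([5, 5, 1, 32])

print(W_conv1.name)   # prints something like 'hidden1/Variable:0'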

Next, record the values you care about; different kinds of values call for different summary ops.

1. For single-valued quantities such as the loss and the accuracy, collect them with tf.summary.scalar():

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_conv, labels=y))
train_step = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cross_entropy)

tf.summary.scalar('loss', cross_entropy)
tf.summary.scalar('accuracy', accuracy)

2. For high-dimensional variables such as the weights W and biases b, collect them with tf.summary.histogram():

    with tf.name_scope('fc2'):
        W_fc2 = weight_variable([1024, 10])
        b_fc2 = bias_variable([10])
        y_conv = tf.matmul(l_fc1_drop, W_fc2) + b_fc2

        tf.summary.histogram('W_fc2', W_fc2)
        tf.summary.histogram('b_fc2', b_fc2)

3. The two above cover most cases, but you can also use tf.summary.image() to collect the input images, so you can see exactly what is being fed into the network:

    with tf.name_scope('hidden1'):
        x = tf.reshape(x_raw, shape=[-1,28,28,1])
        tf.summary.image('x_input',x,max_outputs=10) 
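Here max_outputs=10 means that at most 10 images from each batch are written to the summary; they appear on TensorBoard's Images tab, which is a handy sanity check that the 784-pixel vectors really are being reshaped into 28x28 digits.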

Once all the summaries have been attached, merge them and write them to the event file during training:

with tf.Session(graph=myGraph) as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    merged = tf.summary.merge_all() # merge all summary ops into one
    summary_writer = tf.summary.FileWriter('./mnistEven/', graph=sess.graph) # directory the event file is written to

    for i in range(10001):
        batch = mnist.train.next_batch(50)
        sess.run(train_step,feed_dict={x_raw:batch[0], y:batch[1], keep_prob:0.5})
        if i%100 == 0:
            train_accuracy = accuracy.eval(feed_dict={x_raw:batch[0], y:batch[1], keep_prob:1.0})
            print('step %d training accuracy:%g'%(i, train_accuracy))

            summary = sess.run(merged, feed_dict={x_raw:batch[0], y:batch[1], keep_prob:1.0}) # evaluate the merged summaries
            summary_writer.add_summary(summary, i) # every 100 steps, write the collected summaries to the event file
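One small addition worth considering (it is not in the original code): the FileWriter buffers events in memory, so calling summary_writer.close() once the training loop finishes makes sure everything is flushed to disk before the process exits.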

After training completes, run the following command (for this example, point --logdir at the directory passed to the FileWriter, i.e. ./mnistEven/):

$ tensorboard --logdir=path/to/log-directory

Then open localhost:6006 in your browser to view TensorBoard.

The figures below show the computation graph and the collected summaries:

Computation graph:

[figure: computation graph rendered by TensorBoard]

[figure: collected parameter summaries]

In short, TensorBoard makes the training process visible.

IV. Saving the trained model with Saver

Saving the model is straightforward: create a Saver before training starts, then call saver.save() once training finishes.

with tf.Session(graph=myGraph) as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    for i in range(10001):
        pass  # the training steps from the full listing above go here

    saver.save(sess, save_path='./model/mnistmodel', global_step=1)

Saving the model really means saving its parameters, mainly the weights W and biases b. What ends up in the target directory is a set of binary checkpoint files; after rebuilding the same network, we can restore the parameters directly from these files and use the model right away.
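As a preview, a minimal restore sketch could look like the following (my own assumed usage, not the method from the next post; note that save(..., global_step=1) appends "-1" to the checkpoint name):

with tf.Session(graph=myGraph) as sess:
    saver = tf.train.Saver()
    # restore() loads W and b from the checkpoint, so no variable initialisation is needed
    saver.restore(sess, './model/mnistmodel-1')
    test_accuracy = accuracy.eval(feed_dict={x_raw: mnist.test.images,
                                             y: mnist.test.labels,
                                             keep_prob: 1.0})
    print('restored model test accuracy:%g' % test_accuracy)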

The next post explains in detail how to restore the model and build a recognition engine that identifies the digit in an image. Click here for that post.