1. 程式人生 > >tensorflow 12:雙隱層+softmax迴歸實現mnist圖片識別之二

tensorflow 12:雙隱層+softmax迴歸實現mnist圖片識別之二

概述

tensorflow的自帶例程用兩個檔案演示了“全連線層+softmax迴歸”實現mnist圖片識別的功能。一個檔案是mnist.py,在之前一篇文章《tensorflow 11:雙隱層+softmax迴歸實現mnist圖片識別》已經介紹過了。不過mnist.py側重搭建計算圖,沒有呼叫過程。

本文講解fully_connected_feed.py這個檔案,主要講解呼叫過程相關的知識點。

fully_connected_feed.py速覽

fully_connected_feed.py所做的工作除了正常的引數解析、建立圖、讀取資料、迴圈訓練,還包括了儲存圖結構、資訊彙總、儲存檢查點檔案。

引數解析

fully_connected_feed.py把模型的(超)引數全部解析到全域性變數FLAGS裡面,然後其它地方用FLAGS獲取使用者傳參。

引數解析相關知識請看之前的一篇博文《tensorflow 9. 引數解析和經典入口函式tf.app.run》

狀態視覺化

tensorflow提供了一個工具視覺化訓練過程,該工具叫tensorboard。為tensorboard提供的資訊一般分兩類:圖結構和即時資訊。

儲存圖結構

為了儲存圖結構,在計算圖簡歷完畢以後,只要例項化一個SummaryWriter就行了。

# 例項化一個 SummaryWriter 輸出 summaries 和 Graph.
summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

只要呼叫一次summary_writer.flush()或者summary_writer.close(), 計算圖的結構就會寫入tensorboard的日誌檔案中。 如下: 在這裡插入圖片描述

name_scope

可以看到上圖的結構是分層顯示的,雙擊一個模組,會顯示模組內的詳細資訊。雙擊隱層1之後效果如下: 在這裡插入圖片描述

這種分層的效果是用name_scope實現的:

  with tf.name_scope('hidden1'):
    weights = tf.Variable(
        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
                            stddev=
1.0 / math.sqrt(float(IMAGE_PIXELS))), name='weights') biases = tf.Variable(tf.zeros([hidden1_units]), name='biases') hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)

在結構圖上,所有with tf.name_scope(‘hidden1’)模組內的節點都會包含在‘hidden1’方塊內。可以巢狀使用name_scope。

即時資訊(彙總)

除了圖結構,我們還關心訓練過程的loss值等即時資訊。也可以把這些資訊寫入tensorboard日誌檔案,然後從網頁觀察。彙總即時資訊需要以下幾步。

構建計算圖時

為了儲存即時資訊,在構建計算圖時需要構建即時資訊的節點,並在最優化時指定global_step。比如loss資訊的新增方式如下:

def training(loss, learning_rate):
  # 為儲存loss的值新增一個標量彙總(scalar summary).
  tf.summary.scalar('loss', loss)
  # 根據給定的學習率建立梯度下降優化器
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  # 建立一個變數來跟蹤global step.
  global_step = tf.Variable(0, name='global_step', trainable=False)
  # 在訓練節點,使用optimizer將梯度下降法應用到可調引數上來最小化損失
  # (同時不斷增加 global step 計數器) .
  train_op = optimizer.minimize(loss=loss,global_step=global_step)
  return train_op

除了標量,還可以彙總直方圖、圖片、音訊等資料。

構建計算圖之後

所有的即時資料(在這裡只有一個)都要在圖表構建階段合併至一個操作(op)中。

summary_op = tf.merge_all_summaries()

session建立之後

在建立好會話(session)之後,可以例項化一個tf.train.SummaryWriter,用於寫入包含了圖表本身和即時資料具體值的事件檔案。

# 例項化一個 SummaryWriter 輸出 summaries 和 Graph.
summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

執行summary_op

最後,每次執行summary_op時,都會往事件檔案中寫入最新的即時資料,函式的輸出會傳入SummaryWriter的add_summary()函式,並指定當前的global_step。

if step % 100 == 0:
    # 更新事件檔案.還是呼叫sess.run()方法
    summary_str = sess.run(summary, feed_dict=feed_dict)
    summary_writer.add_summary(summary_str, global_step=step)
    summary_writer.flush()

視覺化總結

借用《TensorFlow實現全連線fully_connected_feed.py》中的一張圖: 在這裡插入圖片描述

即時資訊的彙總效果如下:

在這裡插入圖片描述

儲存檢查點(checkpoint)

為了得到可以用來後續恢復模型以進一步訓練或評估的檢查點檔案(checkpoint file),我們例項化一個tf.train.Saver。這一步不依賴計算圖的建立狀態。

saver = tf.train.Saver(max_to_keep=5)

saver預設只保留最近的5個ckpt檔案,可以通過max_to_keep來改變。

在訓練迴圈中,將定期(隔一定步數)呼叫saver.save()方法,向訓練資料夾中寫入包含了當前所有可訓練變數值得檢查點檔案。

      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=step)

儲存ckpt目錄下的檔案如下:

在這裡插入圖片描述

這樣,我們以後(暫停或斷電後)就可以使用saver.restore()方法,過載模型的引數,繼續訓練。

saver.restore(sess, FLAGS.train_dir)

完整程式碼

上面說了這麼多,還沒上完整的fully_connected_feed.py程式碼。這個是我修改註釋過的:

"""Trains and Evaluates the MNIST network using a feed dictionary."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=missing-docstring
import argparse
import os
import sys
import time

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
import mnist_softmax_2hidden

# 全域性變數,用來存放基本的模型(超)引數.
FLAGS = None

# 產生 placeholder variables 來表達輸入張量
def placeholder_inputs(batch_size):
  """Generate placeholder variables to represent the input tensors.

  These placeholders are used as inputs by the rest of the model building
  code and will be fed from the downloaded data in the .run() loop, below.

  Args:
    batch_size: The batch size will be baked into both placeholders.

  Returns:
    images_placeholder: Images placeholder.
    labels_placeholder: Labels placeholder.
  """
  # Note that the shapes of the placeholders match the shapes of the full
  # image and label tensors, except the first dimension is now batch_size
  # rather than the full size of the train or test data sets.
  images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                         mnist_softmax_2hidden.IMAGE_PIXELS))
  labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
  return images_placeholder, labels_placeholder

# 填充 feed_dict 用於一個指定的訓練階段(given training step)
def fill_feed_dict(data_set, images_pl, labels_pl):
  """Fills the feed_dict for training the given step.

  A feed_dict takes the form of:
  feed_dict = {
      <placeholder>: <tensor of values to be passed for placeholder>,
      ....
  }

  Args:
    data_set: The set of images and labels, from input_data.read_data_sets()
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().

  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
  # Create the feed_dict for the placeholders filled with the next
  # `batch size` examples.
  images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                 FLAGS.fake_data)
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict


def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  """執行一個回合(one epoch)的評估過程.

  Args:
    sess: The session in which the model has been trained.
    eval_correct: The Tensor that returns the number of correct predictions.
    images_placeholder: The images placeholder.
    labels_placeholder: The labels placeholder.
    data_set: The set of images and labels to evaluate, from
      input_data.read_data_sets().
  """
  # And run one epoch of eval.
  true_count = 0  # Counts the number of correct predictions.
  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
  num_examples = steps_per_epoch * FLAGS.batch_size
  for step in xrange(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  precision = float(true_count) / num_examples
  print('Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
        (num_examples, true_count, precision))


def run_training():
  """對MNIST網路訓練指定的次數(一次訓練稱為一個training step)"""
  # 獲取用於訓練,驗證和測試的影象資料以及類別標籤集合
  data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)

  # 告訴TensorFlow,模型將會被構建在預設的Graph上.
  with tf.Graph().as_default():
    # 為影象特徵向量資料和類標籤資料建立輸入佔位符
    images_placeholder, labels_placeholder = placeholder_inputs(
        FLAGS.batch_size)

    # 從前向推斷模型中構建用於預測的計算圖
    logits = mnist_softmax_2hidden.inference(images_placeholder,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # 為計算圖新增計算損失的節點.
    loss = mnist_softmax_2hidden.loss(logits, labels_placeholder)

    # 為計算圖新增計算和應用梯度的訓練節點
    train_op = mnist_softmax_2hidden.training(loss, FLAGS.learning_rate)

    # 新增節點用於在評估過程中比較 logits 和 ground-truth labels .
    eval_correct = mnist_softmax_2hidden.evaluation(logits, labels_placeholder)

    # 基於 TF collection of Summaries構建彙總張量.
    summary = tf.summary.merge_all()

    # 新增變數初始化節點(variable initializer Op).
    init = tf.global_variables_initializer()

    # 建立一個 saver 用於寫入 訓練過程中的模型的檢查點檔案(checkpoints).
    saver = tf.train.Saver(max_to_keep=5)

    # 建立一個會話用來執行計算圖中的節點
    sess = tf.Session()

    # 例項化一個 SummaryWriter 輸出 summaries 和 Graph.
    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

    # And then after everything is built:

    # 執行初始化節點來初始化所有變數(Variables).
    sess.run(init)

    # 開啟訓練迴圈.
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()

      # 使用真實的影象和類標籤資料集填充 feed dictionary
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)

      # 在當前批次樣本上把模型執行一步(run one step).
      # 返回值是從`train_op`和`loss`節點拿到的activations
      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)

      duration = time.time() - start_time

      # 每隔100個批次就寫入summaries並輸出overview
      if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # 更新事件檔案.還是呼叫sess.run()方法
        summary_str = sess.run(summary, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, global_step=step)
        summary_writer.flush()

      # 週期性的儲存一個檢查點檔案並評估當前模型的效能
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=step)
        # 在所有訓練集上評估模型
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.train)
        # 在驗證集上評估模型.
        print('Validation Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.validation)
        # 在測試集上評估模型.
        print('Test Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.test)

# 建立日誌資料夾,啟動訓練過程
def main(_):
  if tf.gfile.Exists(FLAGS.log_dir):
    tf.gfile.DeleteRecursively(FLAGS.log_dir)
  tf.gfile.MakeDirs(FLAGS.log_dir)
  run_training()


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='Initial learning rate.'
  )
  parser.add_argument(
      '--max_steps',
      type=int,
      default=2000,
      help='Number of steps to run trainer.'
  )
  parser.add_argument(
      '--hidden1',
      type=int,
      default=128,
      help='Number of units in hidden layer 1.'
  )
  parser.add_argument(
      '--hidden2',
      type=int,
      default=32,
      help='Number of units in hidden layer 2.'
  )
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='Batch size.  Must divide evenly into the dataset sizes.'
  )
  parser.add_argument(
      '--input_data_dir',
      type=str,
      default='./data',
      help='Directory to put the input data.'
  )
  parser.add_argument(
      '--log_dir',
      type=str,
      default='./logs',
      help='Directory to put the log data.'
  )
  parser.add_argument(
      '--fake_data',
      default=False,
      help='If true, uses fake data for unit testing.',
      action='store_true'
  )

  # 把模型的(超)引數全部解析到全域性變數FLAGS裡面
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

參考資料