TensorFlow中的那些高階API
摘要: 在這篇文章中,我們將看到一個使用了最新高階構件的例子,包括Estimator(估算器)、Experiment(實驗)和Dataset(資料集)。值得注意的是,你可以獨立地使用Experiment和Dataset。不妨進來看看作者是如何玩轉這些高階API的。
TensorFlow擁有很多庫,比如Keras、TFLearn和Sonnet,對於模型訓練來說,使用這些庫比使用低階功能更簡單。儘管Keras的API目前正在新增到TensorFlow中去,但TensorFlow本身就提供了一些高階構件,而且最新的1.3版本中也引入了一些新的構件。
在這篇文章中,我們將看到一個使用了這些最新的高階構件的例子,包括Estimator(估算器)、Experiment(實驗)和Dataset(資料集)。值得注意的是,你可以獨立地使用Experiment和Dataset。我在這裡假設你已經瞭解TensorFlow的基礎知識;如果沒有的話,那麼TensorFlow官網上提供的教程值得學習。
Experiment、Estimator和DataSet框架以及它們之間的互動。
我們在本文中將使用MNIST作為資料集。這是一個使用起來很簡單的資料集,可以從TensorFlow官網獲取到。你可以在這個gist中找到完整的程式碼示例。使用這些框架的其中一個好處是,我們不需要直接處理圖和會話。
Estimator(估算器)類
Estimator類代表了一個模型,以及如何對這個模型進行訓練和評估。我們可以像下面這段程式碼建立一個Estimator:
return tf.estimator.Estimator(
model_fn=model_fn, # First-class function
params=params, # HParams
config=run_config # RunConfig)
要建立Estimator,需要傳入一個模型函式、一組引數和一些配置。
-
傳入的**引數**應該是模型超引數的一個集合。這可以是一個dictionary,但是我們將在這個例子中把它表示成一個HParams物件,就像namedtuple一樣。
-
傳入的**配置**用於指定如何執行訓練和評估,以及在哪裡儲存結果。這個配置是一個RunConfig物件,該物件會把模型執行環境相關的資訊告訴Estimator。
-
模型函式是一個Python函式,它根據給定的輸入構建模型。
模型函式
模型函式是一個Python函式,並作為一級函式傳遞給Estimator。稍後我們會看到,TensorFlow在其他地方也使用了一級函式。將模型表示為一個函式的好處是可以通過例項化函式來多次建立模型。模型可以在訓練過程中用不同的輸入重新建立,例如,在訓練過程中執行驗證測試。
模型函式把**輸入特徵**作為引數,將相應的**標籤**作為張量。它也能以某種方式來告知使用者模型是在訓練、評估或是在執行推理。模型函式的最後一個引數是**超引數**集合,它們與傳遞給Estimator的超引數集合相同。模型函式返回一個**EstimatorSpec**物件,該物件定義了一個完整的模型。
EstimatorSpec物件用於對操作進行預測、損失、訓練和評估,因此,它定義了一個用於訓練、評估和推理的完整的模型圖。由於EstimatorSpec只可用於常規的TensorFlow操作,因此,我們可以使用像TF-Slim這樣的框架來定義模型。
Experiment(實驗)類
Experiment類定義瞭如何訓練模型,它與Estimator完美地整合在一起。我們可以像如下程式碼建立一個Experiment物件:
experiment = tf.contrib.learn.Experiment( estimator=estimator, # Estimator
train_input_fn=train_input_fn, # First-class function
eval_input_fn=eval_input_fn, # First-class function
train_steps=params.train_steps, # Minibatch steps
min_eval_frequency=params.min_eval_frequency, # Eval frequency
train_monitors=[train_input_hook], # Hooks for training
eval_hooks=[eval_input_hook], # Hooks for evaluation
eval_steps=None # Use evaluation feeder until its empty)
以下幾種情況會把Experiment物件作為輸入:
-
一個**estimator**(例如我們上面定義的)。
-
作為一級函式**訓練和評估資料**。這裡使用了與前面提到的模型函式相同的概念。如果需要的話,通過傳入函式而不是操作,可以重新建立輸入圖。稍後我們還會談到這個。
-
訓練和評估hook(鉤子)。鉤子可用於儲存或監視特定的內容,或者在圖或會話中設定某些操作。例如,我們將其傳入到操作中,幫助初始化資料載入器。
-
描述需要訓練多久以及何時評估的各種引數。
一旦定義了experiment,我們就可以像下面這段程式碼那樣使用learn_runner.run來執行它訓練和評估模型:
learn_runner.run( experiment_fn=experiment_fn, # First-class function
run_config=run_config, # RunConfig
schedule="train_and_evaluate", # What to run
hparams=params # HParams)
與模型函式和資料函式一樣,learn_runner
將一個建立experiment的函式作為引數傳入。
Dataset(資料集)類
我們將使用Dataset類和相應的Iterator來表示資料的訓練和評估,以及建立在訓練過程中迭代資料的資料饋送器。 在本示例中,我們將使用在Tensorflow中可用的MNIST資料,併為其構建一個Dataset包裝。例如,我們將把訓練輸入資料表示為:
# Define the training inputsdef get_train_inputs(batch_size, mnist_data):
"""Return the input function to get the training data.
Args:
batch_size (int): Batch size of training iterator that is returned
by the input function.
mnist_data (Object): Object holding the loaded mnist data.
Returns:
(Input function, IteratorInitializerHook):
- Function that returns (features, labels) when called.
- Hook to initialise input iterator.
"""
iterator_initializer_hook = IteratorInitializerHook() def train_inputs():
"""Returns training set as Operations.
Returns:
(features, labels) Operations that iterate over the dataset
on every evaluation
"""
with tf.name_scope('Training_data'): # Get Mnist data
images = mnist_data.train.images.reshape([-1, 28, 28, 1])
labels = mnist_data.train.labels # Define placeholders
images_placeholder = tf.placeholder(
images.dtype, images.shape)
labels_placeholder = tf.placeholder(
labels.dtype, labels.shape) # Build dataset iterator
dataset = tf.contrib.data.Dataset.from_tensor_slices(
(images_placeholder, labels_placeholder))
dataset = dataset.repeat(None) # Infinite iterations
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(batch_size)
iterator = dataset.make_initializable_iterator()
next_example, next_label = iterator.get_next() # Set runhook to initialize iterator
iterator_initializer_hook.iterator_initializer_func = \ lambda sess: sess.run(
iterator.initializer,
feed_dict={images_placeholder: images,
labels_placeholder: labels}) # Return batched (features, labels)
return next_example, next_label # Return function and hook
return train_inputs, iterator_initializer_hook
呼叫這個get_train_inputs
將返回一個一級函式,用於在TensorFlow圖中建立資料載入操作,以及返回一個用於初始化迭代器的Hook
。
本示例中使用的MNIST資料最初是一個Numpy陣列。我們建立了一個佔位符張量來獲取資料;使用佔位符的目的是為了避免資料的複製。接下來,我們在from_tensor_slices的幫助下建立一個切片資料集。我們要確保該資料集可以執行無限次數,並且資料被重新洗牌並放入指定大小的批次中。
要迭代資料,就需要從資料集中建立一個迭代器。由於我們正在使用佔位符,因此需要使用NumPy資料在相關會話中對佔位符進行初始化。可以通過建立一個可初始化的迭代器來實現這個。在建立圖的時候,將建立一個自定義的IteratorInitializerHook物件來初始化迭代器:
class IteratorInitializerHook(tf.train.SessionRunHook):
"""Hook to initialise data iterator after Session is created."""
def __init__(self):
super(IteratorInitializerHook, self).__init__()
self.iterator_initializer_func = None
def after_create_session(self, session, coord):
"""Initialise the iterator after the session has been created."""
self.iterator_initializer_func(session)
IteratorInitializerHook繼承自SessionRunHook。這個鉤子將在相關會話建立後立即呼叫after_create_session,並使用正確的資料初始化佔位符。這個鉤子由我們的get_train_inputs函式返回,並在建立時傳遞給Experiment物件。
train_inputs函式返回的資料載入操作是TensorFlow的操作,該操作每次評估時都會返回一個新的批處理。
執行程式碼
現在,我們已經定義了所有內容,可以使用下面這個命令執行程式碼了:
python mnist_estimator.py --model_dir ./mnist_training --data_dir ./mnist_data
如果不傳入引數,它將使用檔案開頭的預設標誌來確定資料和模型儲存的位置。
在訓練過程中,在終端上會輸出這段時間內的全域性步驟、損失和準確性等資訊。除此之外,Experiment和Estimator框架將記錄TensorBoard視覺化的某些統計資訊。如果我們執行這個命令:
tensorboard --logdir='./mnist_training'
那麼我們可以看到所有的訓練統計資料,如訓練損失、評估準確性、每個步驟的時間,以及模型圖。
TensorBoard視覺化中的評估準確度
我寫這篇文章,是因為我在編寫程式碼示例時,無法找到有關Tensorflow Estimator 、Experiment和Dataset框架太多的資訊和示例。我希望這篇文章能向你簡要介紹一下這些框架是如何工作的,它們採用了什麼樣的抽象方法以及如何使用它們。如果你對使用這些框架感興趣,下面我將介紹一些注意點和其他的文件。
有關Estimator、Experiment和Dataset框架的注意點
-
有一篇名為《TensorFlow Estimators:掌握高階機器學習框架中的簡單性與靈活性》的文章描述了Estimator框架的高級別設計。
-
TensorFlow官網上有更多有關使用Dataset API的文件。
-
有2個版本的Estimator類。在這個例子中,我們使用的是tf.estimator.Estimator,但在tf.contrib.learn.Estimator中還有一個較老的不穩定版本。
-
也有2個版本的RunConfig類。當我們使用tf.contrib.learn.RunConfig的時候,另外還有一個tf.estimator.RunConfig的版本。我無法讓後者與Experiment框架結合在一起,所以我堅持使用tf.contrib版本。
-
雖然我們在這個例子中沒有使用它們,但是Estimator框架定義了典型模型(如分類器和迴歸器)的預定義估算器。這些預定義的估算器使用起來很簡單,並附有詳細的教程。
-
TensorFlow還定義了模型“頭”的抽象,這個“頭”是架構的上層,定義了損失、評估和訓練操作。這個“頭”負責定義模型函式和所有必需的操作。你可以在tf.contrib.learn.Head中找到一個版本。在較新的Estimator框架中也有一個原型版本。在這個例子中我們不打算使用,因為它的開發非常不穩定。
-
本文使用了TensorFlow slim框架來定義模型的架構。 Slim是一個用於定義TensorFlow中複雜模型的輕量級庫。它定義了預定義的架構和預先訓練的模型。
-
更復雜的例子,請參見:https://gist.github.com/peterroelants/9956ec93a07ca4e9ba5bc415b014bcca#file-mnist_estimator-py