1. 程式人生 > >TensorFlow中的小知識:tf.flags.DEFINE_xxx()

TensorFlow中的小知識:tf.flags.DEFINE_xxx()

讀別人家的程式碼的時候經常看到這個,結果兩三天不看居然忘記了,這腦子絕對上鏽了,決定記下來免得老是查來查去的。。。
內容包含如下幾個我們經常看到的幾個函式:
①tf.flags.DEFINE_xxx()
②FLAGS = tf.flags.FLAGS
③FLAGS._parse_flags()

簡單的說:

用於幫助我們新增命令列的可選引數。
也就是說利用該函式我們可以實現在命令列中選擇需要設定的引數來執行程式,
可以不用反覆修改原始碼中的引數,直接在命令列中進行引數的設定。

舉個栗子:

程式train.py檔案中的小部分程式碼如下所示:

FLAGS = tf.flags
.FLAGS tf.flags.DEFINE_string('name', 'default', 'name of the model') tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch') tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq') tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm') tf.flags.DEFINE_integer('num_layers'
, 2, 'number of lstm layers') tf.flags.DEFINE_boolean('use_embedding', False, 'whether to use embedding') tf.flags.DEFINE_integer('embedding_size', 128, 'size of embedding') tf.flags.DEFINE_float('learning_rate', 0.001, 'learning_rate') tf.flags.DEFINE_float('train_keep_prob', 0.5, 'dropout rate during training'
) tf.flags.DEFINE_string('input_file', '', 'utf8 encoded text file') tf.flags.DEFINE_integer('max_steps', 100000, 'max steps to train') tf.flags.DEFINE_integer('save_every_n', 1000, 'save the model every n steps') tf.flags.DEFINE_integer('log_every_n', 10, 'log to the screen every n steps') tf.flags.DEFINE_integer('max_vocab', 3500, 'max char number') #全域性引數設定,顯示在命令列

在命令列中我們為了執行train.py檔案,在命令列中輸入:

python train.py \
  --input_file data/shakespeare.txt  \
  --name shakespeare \
  --num_steps 50 \
  --num_seqs 32 \
  --learning_rate 0.01 \
  --max_steps 20000

通過輸入不同的檔名、引數,可以快速完成程式的調參和更換訓練集的操作,不需要進入原始碼中更改。

備註:在此感謝上述程式碼的作者

實踐操作一下:

現在我們有如下程式碼:

import tensorflow as tf
#取上述程式碼中一部分進行實驗
tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')
tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')
tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')

#通過print()確定下面內容的功能
FLAGS = tf.flags.FLAGS #FLAGS儲存命令列引數的資料
FLAGS._parse_flags() #將其解析成字典儲存到FLAGS.__flags中
print(FLAGS.__flags)

print(FLAGS.num_seqs)

print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")

嘗試執行一下上述程式碼瞭解其各行程式碼的功能,可能因為tensorflow版本原因出現報錯現象。
檢視解決辦法可點選連結