tensorflow命令列引數:tf.app.flags.DEFINE_string、tf.app.flags.DEFINE_integer、tf.app.flags.DEFINE_boolean
阿新 • • 發佈:2019-01-02
tf 中定義了 tf.app.flags.FLAGS ,用於接受從終端傳入的命令列引數,相當於對Python中的命令列引數模組optpars(參考: python中處理命令列引數的模組optpars )做了一層封裝。
optpars中的引數型別是通過引數 “type=xxx” 定義的,tf中每個合法型別都有對應的 “DEFINE_xxx”函式。常用:
定義一個名稱是 "ckpt_path" 的變數,預設值是 ckpt_path = 'model/model.ckpt-100000',描述資訊表明這是一個用於儲存節點資訊的路徑。
optpars中的引數型別是通過引數 “type=xxx” 定義的,tf中每個合法型別都有對應的 “DEFINE_xxx”函式。常用:
- tf.app.flags.DEFINE_string() :定義一個用於接收 string 型別數值的變數;
- tf.app.flags.DEFINE_integer() : 定義一個用於接收 int 型別數值的變數;
- tf.app.flags.DEFINE_float() : 定義一個用於接收 float 型別數值的變數;
- tf.app.flags.DEFINE_boolean() : 定義一個用於接收 bool 型別數值的變數;
“DEFINE_xxx”函式帶3個引數,分別是變數名稱,預設值,用法描述,例如:
tf.app.flags.DEFINE_string('ckpt_path', 'model/model.ckpt-100000', '''Checkpoint directory to restore''')
定義一個名稱是 "ckpt_path" 的變數,預設值是 ckpt_path = 'model/model.ckpt-100000',描述資訊表明這是一個用於儲存節點資訊的路徑。
example:
# -*- coding=utf-8 -*- import tensorflow as tf FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('ckpt_path', 'model/model.ckpt-100000', '''模型儲存路徑''') tf.app.flags.DEFINE_float('learning_rate',0.0001,'''初始學習率''') tf.app.flags.DEFINE_integer('train_steps', 50000, '''總的訓練輪數''') tf.app.flags.DEFINE_boolean('is_use_gpu', False, '''是否使用GPU''') print '模型儲存路徑: {}'.format(FLAGS.ckpt_path) print '初始學習率: {}'.format(FLAGS.learning_rate) print '總的訓練次數: {}'.format(FLAGS.train_steps) print '是否使用GPU: {}'.format(FLAGS.is_use_gpu)
使用 '-h' 指令檢視幫助資訊:
python flags_test.py -h
按預設設定執行程式:
傳入使用者自定義的命令列引數:
python flags_test.py --ckpt_path abc/cba --learning_rate 0.001 --train_steps 10000 --is_use_gpu True