1. 程式人生 > >tensorflow__第一章:命令列引數的設定(flags=tf.app.flags)

tensorflow__第一章:命令列引數的設定(flags=tf.app.flags)

在執行main函式之前首先進行flags的解析,也就是說TensorFlow通過設定flags來傳遞tf.app.run()所需要的引數,我們可以直接在程式執行前初始化flags,也可以在執行程式的時候設定命令列引數來達到傳參的目的。

tf.app.flags的使用

flags = tf.app.flags flags.DEFINE_integer("epoch", 1000, "Epoch to train [25]")

flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")

flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")

flags.DEFINE_integer("train_size", 256, "The size of train images [np.inf]")

flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")

flags.DEFINE_string("dataset", "mnist", "The name of dataset [celebA, mnist, lsun]")

flags.DEFINE_boolean("train", True, "True for training, False for testing [False]")

FLAGS = flags.FLAGS

在類的初始化時可以賦值給類中的成員,例如:

dcgan = DCGAN(

sess,

input_width=FLAGS.input_width,

input_height=FLAGS.input_height,

output_width=FLAGS.output_width,

output_height=FLAGS.output_height,

batch_size=FLAGS.batch_size,

sample_num=FLAGS.batch_size,

y_dim=10,

dataset_name=FLAGS.dataset,

input_fname_pattern=FLAGS.input_fname_pattern,

crop=FLAGS.crop,

checkpoint_dir=FLAGS.checkpoint_dir,

sample_dir=FLAGS.sample_dir)

命令列的命名格式:

#第一個是引數名稱,第二個引數是預設值,第三個是引數描述

tf.app.flags.DEFINE_string('str_name', 'def_v_1',"descrip1")

tf.app.flags.DEFINE_integer('int_name', 10,"descript2")

tf.app.flags.DEFINE_boolean('bool_name', False, "descript3")

FLAGS = tf.app.flags.

FLAGS #必須帶引數,否則:'TypeError: main() takes no arguments (1 given)';

main的引數名隨意定義,無要求 def main(_): print(FLAGS.str_name) print(FLAGS.int_name) print(FLAGS.bool_name) if __name__ == '__main__':#避免出現import的時候呼叫main() tf.app.run() #執行main函式