1. 程式人生 > >TensorFlow API之tf.estimator.Estimator

TensorFlow API之tf.estimator.Estimator

tf.estimator.Estimator

Estimator class訓練和測試TF模型。Estimator物件封裝好通過model_fn指定的模型,給定輸入和其它超引數,返回ops執行training, evaluation or prediction. 所有的輸出(包含checkpoints, event files, etc.)被寫入model_dir

屬性

  • config 
    傳入model_fn,如果model_fn有引數named “config”
  • model_dir
  • model_fn  The model_fn with following signature: def model_fn(features, labels, mode, config)
  • params

方法

  • __init__
__init__(
    model_fn,
    model_dir=None,
    config=None,
    params=None # 將要傳入model_fn的超引數字典
)
  • evaluate

對訓練模型評價

evaluate(
    input_fn, # 輸入函式,返回元組features和labels
    steps=None,
    hooks=None, # List of SessionRunHook subclass instances
    checkpoint_path=None, # if none, 用model_dir中latest checkpoint
    name=None
)
  • export_savemodel  匯出inference graph作為一個SavedModel
export_savedmodel(
    export_dir_base, # 目錄
    serving_input_receiver_fn, # 返回ServingInputReceiver的函式
    assets_extra=None,
    as_text=False,
    checkpoint_path=None
)
  • get_variable_names

    get_variable_names()  返回模型中所有變數名字的列表

  • get_variable_value(name)  根據變數name返回value

  • latest_checkpoint()  在model_dir中找到最近儲存的checkpoint

  • predict  根據給定的features產生預測

predict(
    input_fn,
    predict_keys=None,
    hooks=None,
    checkpoint_path=None
)
  • train

給定訓練資料後訓練model

train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)