tf.estimator.Estimator

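Overview

Estimator is a class, so it must be instantiated before use. Its job is to train and evaluate TensorFlow models.
An Estimator object wraps a model specified by a function named model_fn. Given inputs and a number of other parameters, model_fn returns the ops needed to perform training, evaluation, or prediction. All outputs (checkpoints, event files, etc.) are written to model_dir or a subdirectory of it. If model_dir is not set, a temporary directory is used.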

Initialization

__init__(
    model_fn,
    model_dir=None,
    config=None,
    params=None,
    warm_start_from=None
)

'''
Args:
	model_fn: Model function. Follows the signature:
		Args:
			features: The first item returned from input_fn: a single tensor or a dict of tensors. This is the model's input (where earlier code fed data through tf.placeholder, here it is returned by input_fn).
			labels: The second item returned from input_fn: a single tensor or a dict of tensors. Note that when mode is tf.estimator.ModeKeys.PREDICT (that is, during prediction), labels is set to None.
			mode: Optional. Specifies if this is training, evaluation or prediction. See tf.estimator.ModeKeys.
			params: Optional dict of hyperparameters. Receives the params argument that was passed when the Estimator instance was created.
			config: Optional estimator.RunConfig object. Receives the config argument that was passed when the Estimator instance was created, or a default value. Allows setting up things in your model_fn based on configuration such as num_ps_replicas, or model_dir.
		Returns:
			tf.estimator.EstimatorSpec. Note that model_fn must return an EstimatorSpec instance.
	model_dir: Output directory. Everything the model writes goes here.
	config: A RunConfig object holding the fixed, officially defined configuration options. If these do not cover your needs and you want to add your own parameters, use params below.
	params: dict of hyper parameters that will be passed into model_fn. Keys are names of parameters, values are basic python types.
	warm_start_from: Optional string filepath to a checkpoint or SavedModel to warm-start from, or a tf.estimator.WarmStartSettings object to fully configure warm-starting. If the string filepath is provided instead of a tf.estimator.WarmStartSettings, then all variables are warm-started, and it is assumed that vocabularies and tf.Tensor names are unchanged.
'''

Key points

The config argument can be passed tf.estimator.RunConfig object containing information about the execution environment. It is passed on to the model_fn, if the model_fn has a parameter named “config” (and input functions in the same manner). If the config parameter is not passed, it is instantiated by the Estimator. Not passing config means that defaults useful for local execution are used. Estimator makes config available to the model (for instance, to allow specialization based on the number of workers available), and also uses some of its fields to control internals, especially regarding checkpointing.

The params argument contains hyperparameters. It is passed to the model_fn, if the model_fn has a parameter named “params”, and to the input functions in the same manner. Estimator only passes params along, it does not inspect it. The structure of params is therefore entirely up to the developer.
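The rule that params and config are forwarded only when model_fn declares a parameter of that name can be sketched in plain Python. This is an illustration under stated assumptions, not TensorFlow's actual implementation: call_model_fn is a hypothetical helper, the two model functions are made up, and they return plain dicts in place of a real tf.estimator.EstimatorSpec.

```python
import inspect


def call_model_fn(model_fn, features, labels, mode, params, config):
    """Hypothetical helper: forward mode/params/config only if model_fn declares them."""
    declared = inspect.signature(model_fn).parameters
    optional = {"mode": mode, "params": params, "config": config}
    # Keep only the optional arguments whose names appear in the signature.
    kwargs = {name: value for name, value in optional.items() if name in declared}
    return model_fn(features, labels, **kwargs)


def model_fn_with_params(features, labels, mode, params):
    # Stand-in for a real model_fn; a real one would return an EstimatorSpec.
    return {"mode": mode, "learning_rate": params["learning_rate"]}


def model_fn_minimal(features, labels):
    # Declares neither params nor config, so neither is passed in.
    return {"num_features": len(features)}


hyperparams = {"learning_rate": 0.1}
run_config = {"model_dir": "/tmp/model"}  # stand-in for a RunConfig object

print(call_model_fn(model_fn_with_params, {"x": [1.0]}, [0], "train", hyperparams, run_config))
print(call_model_fn(model_fn_minimal, {"x": [1.0]}, [0], "train", hyperparams, run_config))
```

The helper never looks inside params. That matches the point above that the structure of params is entirely up to the developer.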

Methods

train method

Gets data from input_fn and uses it to train the model.

train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)

'''
Args:
	input_fn: A function that provides input data for training as minibatches. See Premade Estimators for more information. The function should construct and return one of the following:
		* A tf.data.Dataset object: Outputs of Dataset object must be a tuple (features, labels) with same constraints as below.
		* A tuple (features, labels): Where features is a tf.Tensor or a dictionary of string feature name to Tensor and labels is a Tensor or a dictionary of string label name to Tensor. Both features and labels are consumed by model_fn. They should satisfy the expectation of model_fn from inputs.
	hooks: List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the training loop.
	steps: Number of steps for which to train the model. If None, train forever or train until input_fn generates the tf.errors.OutOfRange error or StopIteration exception. steps works incrementally. If you call two times train(steps=10) then training occurs in total 20 steps. If OutOfRange or StopIteration occurs in the middle, training stops before 20 steps. If you don't want to have incremental behavior please set max_steps instead. If set, max_steps must be None.
	max_steps: Number of total steps for which to train model. If None, train forever or train until input_fn generates the tf.errors.OutOfRange error or StopIteration exception. If set, steps must be None. If OutOfRange or StopIteration occurs in the middle, training stops before max_steps steps. Two calls to train(steps=100) means 200 training iterations. On the other hand, two calls to train(max_steps=100) means that the second call will not do any iteration since first call did all 100 steps.
	saving_listeners: list of CheckpointSaverListener objects. Used for callbacks that run immediately before or after checkpoint savings.
Returns:
	self, for chaining.

'''

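Main parameters

input_fn: a function that supplies input data for training, one minibatch at a time. It returns data in the form (features, labels), which is exactly what model_fn takes as input. The return value should be one of the following:

  1. tf.data.Dataset object: Outputs of Dataset object must be a tuple (features, labels)
  2. A tuple (features, labels): Where features is a tf.Tensor or a dictionary of string feature name to Tensor and labels is a Tensor or a dictionary of string label name to Tensor

max_steps: the maximum total number of training steps (that is, the number of batches processed). When training resumes after a pause, the program checks how many steps have already been trained; if that count is greater than or equal to max_steps, no further training takes place. For example, two calls to train(max_steps=100) mean that the second call does not run any iterations, since the first call already did all 100 steps.

steps: trains incrementally on top of whatever has already been done. For example, if you call train(input_fn, steps=10) twice, the model has been trained for 20 iterations in total.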
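The difference between steps (incremental) and max_steps (absolute) can be shown with a small pure-Python simulation. This is only a sketch of the documented counting behavior: ToyEstimator and its num_batches argument are hypothetical, and the global step is held in memory instead of being restored from a checkpoint in model_dir.

```python
class ToyEstimator:
    """Hypothetical stand-in that only tracks the global step."""

    def __init__(self):
        self.global_step = 0

    def train(self, num_batches, steps=None, max_steps=None):
        """Simulate train(); num_batches is how many batches the input can supply."""
        if steps is not None and max_steps is not None:
            raise ValueError("Set only one of steps and max_steps.")
        if steps is not None:
            target = self.global_step + steps  # incremental: counted from the current step
        else:
            target = max_steps  # absolute total; None means run until input ends
        consumed = 0
        while consumed < num_batches and (target is None or self.global_step < target):
            self.global_step += 1
            consumed += 1
        return self  # returned for chaining, as the real train() does


incremental = ToyEstimator()
incremental.train(num_batches=1000, steps=10).train(num_batches=1000, steps=10)
print(incremental.global_step)

capped = ToyEstimator()
capped.train(num_batches=1000, max_steps=100).train(num_batches=1000, max_steps=100)
print(capped.global_step)
```

When the input runs out first, for example num_batches=5 with steps=10, the loop stops early. This mirrors the note that training stops before the requested step count if the input is exhausted.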

evaluate method

Evaluates the model given evaluation data input_fn.
For each step, calls input_fn, which returns one batch of data, and feeds that batch to the model. Evaluates until:

  - steps batches are processed, or
  - input_fn raises an end-of-input exception.

evaluate(
    input_fn,
    steps=None,
    hooks=None,
    checkpoint_path=None,
    name=None
)

'''
Args:
		input_fn: A function that constructs the input data for evaluation. See Premade Estimators for more information. The function should construct and return one of the following:
			* A tf.data.Dataset object: Outputs of Dataset object must be a tuple (features, labels) with same constraints as below.
			* A tuple (features, labels): Where features is a tf.Tensor or a dictionary of string feature name to Tensor and labels is a Tensor or a dictionary of string label name to Tensor. Both features and labels are consumed by model_fn. They should satisfy the expectation of model_fn from inputs.
		steps: Number of steps for which to evaluate model. If None, evaluates until input_fn raises an end-of-input exception.
		hooks: List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the evaluation call.
		checkpoint_path: Path of a specific checkpoint to evaluate. If None, the latest checkpoint in model_dir is used. If there are no checkpoints in model_dir, evaluation is run with newly initialized Variables instead of ones restored from checkpoint.
		name: Name of the evaluation if user needs to run multiple evaluations on different data sets, such as on training data vs test data. Metrics for different evaluations are saved in separate folders, and appear separately in tensorboard.

Returns:
		A dict containing the evaluation metrics specified in model_fn keyed by name, as well as an entry global_step which contains the value of the global step for which this evaluation was performed. For canned estimators, the dict contains the loss (mean loss per mini-batch) and the average_loss (mean loss per sample). Canned classifiers also return the accuracy. Canned regressors also return the label/mean and the prediction/mean.
'''

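Parameters

The parameters are similar to those of the train method, so they are not repeated here. The main thing to cover is the return value.
evaluate returns a dict of the metrics specified ahead of time in model_fn, plus a global_step entry holding the global step at which the evaluation was performed.
For example, model_fn typically contains a definition like this: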

    estim_specs=tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=pred_classes,
        loss=loss_op,
        train_op=train_op,
        eval_metric_ops={"accuracy":acc_op})

Because of eval_metric_ops={"accuracy": acc_op} in that definition, evaluate ends up returning something like this:

{'accuracy': 0.9192, 'loss': 0.28470048, 'global_step': 1000}
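How the stopping rule and the returned dict fit together can be sketched without TensorFlow. The sketch makes several assumptions: toy_evaluate is a hypothetical function, each batch is summarized as a made-up tuple (num_correct, num_examples, batch_loss), and loss is averaged per batch, which is how the docstring describes loss for canned estimators.

```python
from itertools import islice


def toy_evaluate(batches, steps=None, global_step=0):
    """Hypothetical sketch: consume batches until `steps` are processed or input ends."""
    total_correct = 0
    total_examples = 0
    loss_sum = 0.0
    num_batches = 0
    # islice stops after `steps` batches; with steps=None it runs to the end of input.
    for correct, examples, batch_loss in islice(batches, steps):
        total_correct += correct
        total_examples += examples
        loss_sum += batch_loss
        num_batches += 1
    if num_batches == 0:
        raise ValueError("No batches were evaluated.")
    return {
        "accuracy": total_correct / total_examples,
        "loss": loss_sum / num_batches,  # mean loss per batch
        "global_step": global_step,      # step of the checkpoint being evaluated
    }


# Made-up per-batch summaries: (num_correct, num_examples, batch_loss)
eval_batches = [(9, 10, 0.2), (8, 10, 0.4), (10, 10, 0.0)]

print(toy_evaluate(eval_batches, steps=2, global_step=1000))
print(toy_evaluate(eval_batches, global_step=1000))
```

With steps=2 only the first two batches contribute. With steps=None the loop runs until the input is exhausted.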

predict method

predict(
    input_fn,
    predict_keys=None,
    hooks=None,
    checkpoint_path=None,
    yield_single_examples=True
)

'''
Args:
	input_fn: A function that constructs the features. Prediction continues until input_fn raises an end-of-input exception (tf.errors.OutOfRangeError or StopIteration). See Premade Estimators for more information. The function should construct and return one of the following:
		* A tf.data.Dataset object: Outputs of Dataset object must have same constraints as below.
		* features: A tf.Tensor or a dictionary of string feature name to Tensor. features are consumed by model_fn. They should satisfy the expectation of model_fn from inputs.
		* A tuple, in which case the first item is extracted as features.
	predict_keys: list of str, name of the keys to predict. It is used if the tf.estimator.EstimatorSpec.predictions is a dict. If predict_keys is used then rest of the predictions will be filtered from the dictionary. If None, returns all.
	
	hooks: List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the prediction call.
	
	checkpoint_path: Path of a specific checkpoint to predict. If None, the latest checkpoint in model_dir is used. If there are no checkpoints in model_dir, prediction is run with newly initialized Variables instead of ones restored from checkpoint.
	
	yield_single_examples: If False, yields the whole batch as returned by the model_fn instead of decomposing the batch into individual elements. This is useful if model_fn returns some tensors whose first dimension is not equal to the batch size.
'''

Notes

Given inputs, predict yields whatever model_fn specifies as its predictions, for example via tf.estimator.EstimatorSpec(mode, predictions=pred_classes):

    ....
    ....

    pred_classes = tf.argmax(logits, axis=1)
    pred_probas = tf.nn.softmax(logits)

    # Prediction mode: return only the predictions
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=pred_classes)
    .....
    ......

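The parameters are mostly the same as those of the train method, so only the following are covered here:
predict_keys: a list of str. If predict_keys is given, the model yields only the entries of predictions whose keys appear in predict_keys.
checkpoint_path: path of the specific checkpoint to predict with. If None, the latest checkpoint in model_dir is used. If model_dir contains no checkpoints, prediction runs with newly initialized variables instead of ones restored from a checkpoint.
yield_single_examples: if False, yields the whole batch as returned by model_fn instead of decomposing the batch into individual elements. This is useful if model_fn returns some tensors whose first dimension is not equal to the batch size.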
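The effect of predict_keys and yield_single_examples can be illustrated with a pure-Python generator. This is a sketch under stated assumptions and not the library's code: toy_predict is hypothetical, each batch is a made-up dict mapping a prediction name to a list of per-example values, and every list is assumed to have the batch size as its length.

```python
def toy_predict(batches, predict_keys=None, yield_single_examples=True):
    """Hypothetical sketch of filtering by key and splitting batches into examples."""
    for batch in batches:
        if predict_keys is not None:
            missing = [key for key in predict_keys if key not in batch]
            if missing:
                raise ValueError("Unknown prediction keys: %s" % missing)
            # Keep only the requested entries of the predictions dict.
            batch = {key: batch[key] for key in predict_keys}
        if not yield_single_examples:
            yield batch  # the whole batch, as produced by the model
            continue
        # Assumes every value has one entry per example in the batch.
        batch_size = len(next(iter(batch.values()), []))
        for i in range(batch_size):
            yield {key: values[i] for key, values in batch.items()}


# Made-up model outputs for two batches (sizes 2 and 1).
outputs = [
    {"classes": [1, 0], "probabilities": [[0.1, 0.9], [0.8, 0.2]]},
    {"classes": [1], "probabilities": [[0.3, 0.7]]},
]

print(list(toy_predict(outputs, predict_keys=["classes"])))
print(list(toy_predict(outputs, yield_single_examples=False)))
```

Because the result is a generator, nothing is computed until it is iterated, so wrap the call in list(...) or a for loop to consume the predictions.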