
TensorFlow Notes (1): Understanding tf.control_dependencies and control_flow_ops.with_dependencies

Introduction

When implementing neural networks we often see tf.control_dependencies in use, but what exactly does this function do, and when should we use it? Let's take a closer look.

Understanding

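Taken literally, control_dependencies means "control dependencies", so we can roughly guess that this function is used to control the dependencies between nodes of the computation graph. That is exactly what it does: tf.control_dependencies() controls the computation graph by specifying an execution order for certain nodes, forcing some ops to run only after others have run, even when there is no data dependency between them.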

Signature

tf.control_dependencies(control_inputs)
  Arguments:
    control_inputs: A list of `Operation` or `Tensor` objects which must be executed or computed before running the operations defined in the context. (Note that control_inputs is a list.)
  Returns:
    A context manager that specifies control dependencies for all operations constructed within the context.

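From this description we can see that the argument control_inputs is a list of Operation or Tensor objects, and that the return value is a context manager controlling the dependencies of the operations created inside it. In other words, the operations defined under the context manager depend on the operations in control_inputs: control_dependencies guarantees that the operations in control_inputs have run before any operation defined in the context runs. Note that the dependency is attached only to ops that are created inside the with block; merely referring to an already existing tensor or variable inside the block adds no dependency.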
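The rule described above can be mimicked in a few lines of plain Python. The ToyGraph and ToyOp classes below are hypothetical and are not how TensorFlow is implemented; they only model the behavior: an op created while a control_dependencies block is open records that block's control inputs, and running it (like one sess.run) first runs those control inputs, with each op executing at most once per run.

```python
from contextlib import contextmanager


class ToyOp:
    """A node in the toy graph: a Python callable plus its dependencies."""

    def __init__(self, graph, name, fn, inputs, control_inputs):
        self.graph = graph
        self.name = name
        self.fn = fn
        self.inputs = inputs                  # data dependencies
        self.control_inputs = control_inputs  # ordering-only dependencies

    def run(self, cache):
        # Each op runs at most once per "session run" (one shared cache).
        if id(self) in cache:
            return cache[id(self)]
        for dep in self.control_inputs:       # control inputs first...
            dep.run(cache)
        args = [op.run(cache) for op in self.inputs]  # ...then data inputs
        result = self.fn(*args)
        self.graph.log.append(self.name)
        cache[id(self)] = result
        return result


class ToyGraph:
    """Hypothetical toy graph; NOT TensorFlow's implementation."""

    def __init__(self):
        self._control_stack = []  # one list per open control_dependencies block
        self.log = []             # names of ops in the order they executed

    @contextmanager
    def control_dependencies(self, control_inputs):
        self._control_stack.append(list(control_inputs))
        try:
            yield
        finally:
            self._control_stack.pop()

    def op(self, name, fn, inputs=()):
        # Ops created inside an open block inherit all enclosing control inputs.
        controls = [c for block in self._control_stack for c in block]
        return ToyOp(self, name, fn, list(inputs), controls)


def demo():
    g = ToyGraph()
    state = {"w": 1.0}

    def update():
        state["w"] += 1.0

    update_op = g.op("update", update)
    with g.control_dependencies([update_op]):
        guarded_read = g.op("guarded_read", lambda: state["w"])
    plain_read = g.op("plain_read", lambda: state["w"])  # created outside

    first = plain_read.run({})     # no control input: reads the old value
    second = guarded_read.run({})  # runs "update" first, then reads
    return first, second, g.log


if __name__ == "__main__":
    print(demo())  # (1.0, 2.0, ['plain_read', 'update', 'guarded_read'])
```

plain_read was created outside the block, so it never triggers the update; guarded_read was created inside it, so running it runs update first.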

Example 1

If we want to be sure that we always get the parameters after the update, we can organize the code like this:

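import tensorflow as tf  # TensorFlow 1.x API

# tf.train.Optimizer is an abstract base class, so use a concrete optimizer
# (the learning rate here is only an illustrative value)
opt = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss)

with tf.control_dependencies([opt]):  # run opt first
  # tf.identity creates a new op inside the block, so it picks up the dependency
  updated_weight = tf.identity(weight)  # then run this op

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  sess.run(updated_weight, feed_dict={...})  # every call returns the weight after the update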
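For readers on TensorFlow 2.x, here is a self-contained sketch of Example 1 using the tf.compat.v1 graph-mode API. The toy loss (w - 3)^2, the learning rate 0.1 and the variable name are assumptions chosen for illustration; with plain gradient descent the weight should move from 0 to about 0.6, 1.08 and 1.464 over three steps, and fetching only updated_weight is enough to trigger each step.

```python
import tensorflow as tf


def run_example_1(steps=3):
    graph = tf.Graph()
    with graph.as_default():  # build a TF1-style graph (also works under TF 2.x)
        weight = tf.compat.v1.Variable(0.0, name="weight")
        loss = tf.square(weight - 3.0)  # hypothetical toy loss, minimum at 3.0
        opt = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(
            loss, var_list=[weight])

        with tf.control_dependencies([opt]):
            # A new op created inside the block, so it waits for `opt`.
            updated_weight = tf.identity(weight)

        init = tf.compat.v1.global_variables_initializer()

    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(init)
        # Each fetch of updated_weight runs one optimizer step first.
        return [float(sess.run(updated_weight)) for _ in range(steps)]


if __name__ == "__main__":
    print(run_example_1())  # roughly [0.6, 1.08, 1.464]
```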

control_flow_ops.with_dependencies

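Besides the commonly used tf.control_dependencies(), we will also come across control_flow_ops.with_dependencies(). Both functions can express control dependencies, but they do it differently: tf.control_dependencies() is a context manager that affects every op created inside it, while with_dependencies() takes a list of dependencies and a single output_tensor and returns a new tensor carrying the value of output_tensor that becomes available only after the dependencies have run.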

with_dependencies(dependencies, output_tensor, name=None)
  Produces the content of `output_tensor` only after `dependencies`.
  (In short: the returned tensor yields the value of output_tensor only after all ops in dependencies have finished.)

  In some cases, a user may want the output of an operation to be
  consumed externally only after some other dependencies have run
  first. This function returns `output_tensor`, but only after all
  operations in `dependencies` have run. Note that this means that there is
  no guarantee that `output_tensor` will be evaluated after any `dependencies`
  have run.

  See also `tf.tuple` and `tf.group`.

  Args:
    dependencies: Iterable of operations to run before this op finishes.
    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
    name: (Optional) A name for this operation.

  Returns:
    Same as `output_tensor`.

  Raises:
    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.

The last note is the key difference: with_dependencies only delays the tensor it returns. The op that originally produced output_tensor gets no new dependency, so it may run before the ops in dependencies do.
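A hedged sketch of what with_dependencies does. As far as I can tell from the TensorFlow source, in graph mode it essentially wraps output_tensor in an identity op created under control_dependencies (plus name scoping and colocation). Because control_flow_ops is an internal module, the helper with_dependencies_sketch below is a hypothetical stand-in built only from public APIs. The example also shows the docstring's caveat: fetching the original tensor directly does not trigger the dependency.

```python
import tensorflow as tf


def with_dependencies_sketch(dependencies, output_tensor, name=None):
    """Hypothetical public-API stand-in for control_flow_ops.with_dependencies."""
    with tf.control_dependencies(dependencies):
        # Only this new identity op waits for `dependencies`; the op that
        # produced `output_tensor` gets no extra dependency.
        return tf.identity(output_tensor, name=name)


def run_with_dependencies_example():
    graph = tf.Graph()
    with graph.as_default():
        counter = tf.compat.v1.Variable(0, name="counter")
        bump = counter.assign_add(1, read_value=False)  # the "dependency" op
        value = tf.constant(42)
        guarded = with_dependencies_sketch([bump], value, name="guarded")
        init = tf.compat.v1.global_variables_initializer()

    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(init)
        a = sess.run(guarded)  # runs bump first, then yields 42
        b = sess.run(value)    # no dependency attached: bump does not run
        count = sess.run(counter)
    return int(a), int(b), int(count)


if __name__ == "__main__":
    print(run_with_dependencies_example())  # (42, 42, 1)
```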

Example 2

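import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

# Get the ops stored in the UPDATE_OPS collection (e.g. batch-norm moving-average updates); returns a list
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)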
......
total_loss, clones_gradients = model_deploy.optimize_clones(
            clones,
            optimizer,
            var_list=variables_to_train)
......
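# tf.group() bundles several tensors or ops into a single op; running it runs all of them
update_op = tf.group(*update_ops)
# train_tensor yields the value of total_loss, but only after update_op has run
train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                  name='train_op')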
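Below is a minimal, self-contained sketch of the Example 2 pattern under TensorFlow 2.x's tf.compat.v1 API. Everything model-specific is an assumption: a single variable w with toy loss (w - 3)^2 replaces model_deploy.optimize_clones, and a hypothetical non-trainable variable moving, whose update op is registered in UPDATE_OPS by hand, stands in for real batch-norm statistics. Since the elided parts of the original are not shown, the sketch also appends the gradient step to the group, so that fetching train_tensor actually trains. The public tf.control_dependencies plus tf.identity form replaces the internal control_flow_ops call.

```python
import tensorflow as tf


def run_example_2(steps=2):
    graph = tf.Graph()
    with graph.as_default():
        w = tf.compat.v1.Variable(0.0, name="w")
        # Hypothetical stand-in for batch-norm statistics.
        moving = tf.compat.v1.Variable(0.0, trainable=False, name="moving")
        tf.compat.v1.add_to_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS,
            moving.assign_add(1.0, read_value=False))

        total_loss = tf.square(w - 3.0)  # hypothetical toy loss
        train_step = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(
            total_loss, var_list=[w])

        # get_collection returns a copy, so appending does not alter the collection.
        update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        update_ops.append(train_step)      # assumption, see the lead-in
        update_op = tf.group(*update_ops)  # one op that runs all of them
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name="train_op")

        init = tf.compat.v1.global_variables_initializer()

    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(init)
        for _ in range(steps):
            sess.run(train_tensor)  # each fetch runs update_op first
        return float(sess.run(w)), float(sess.run(moving))


if __name__ == "__main__":
    print(run_example_2())  # roughly (1.08, 2.0)
```

Only train_tensor is fetched, yet both the moving-average stand-in and the weight are updated on every step, which is the point of wrapping the loss this way.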

References