
TensorFlow - TF-Slim wrapper functions for variable creation and management - Variables.

Credit to the original author: http://www.aiuai.cn/aifarm316.html

TF-Slim provides wrapper functions for controlling and managing variables, including variable-restoration functions such as get_variables and get_variables_to_restore.

The main functions in the Variables module are:

  • add_model_variable,
  • assert_global_step,
  • assert_or_get_global_step,
  • assign_from_checkpoint,
  • assign_from_checkpoint_fn,
  • assign_from_values,
  • assign_from_values_fn,
  • create_global_step,
  • filter_variables,
  • get_global_step,
  • get_or_create_global_step,
  • get_local_variables,
  • get_model_variables,
  • get_trainable_variables,
  • get_unique_variable,
  • get_variables_by_name,
  • get_variables_by_suffix,
  • get_variable_full_name,
  • get_variables_to_restore,
  • get_variables,
  • global_variable,
  • local_variable,
  • model_variable,
  • variable,
  • VariableDeviceChooser,
  • zero_initializer - initializes a variable to all zeros

1. zero_initializer

def zero_initializer(ref, use_locking=True, name="zero_initializer"):
  """
  Initializes ref with all zeros.
  The input ref tensor should be uninitialized;
  if ref is already initialized, a ValueError is raised.
  This is used to save the memory needed during initialization.

  Args:
    ref: ref of the tensor need to be zero initialized.
    name: optional name for this operation.
  Returns:
    ref that is initialized.
  Raises:
    ValueError: If ref tensor is initialized.
  """
  loader.load_op_library(resource_loader.get_path_to_datafile("_variable_ops.so"))
  if resource_variable_ops.is_resource_variable(ref):
      return gen_variable_ops.zero_var_initializer(ref.handle, 
                                                   shape=ref.shape, 
                                                   dtype=ref.dtype, name=name)
  else:
      return gen_variable_ops.zero_initializer(ref, name=name)
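
A minimal usage sketch (assuming the TF 1.x contrib API, where this function is exposed as tf.contrib.framework.zero_initializer): running the returned op zero-fills the variable without ever evaluating its initial_value tensor:

import tensorflow as tf

v = tf.Variable(tf.random_normal([1000, 1000]))         # initial value is never evaluated
zero_init_op = tf.contrib.framework.zero_initializer(v)
with tf.Session() as sess:
    sess.run(zero_init_op)                 # instead of sess.run(v.initializer)
    print(sess.run(tf.reduce_sum(v)))      # 0.0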

2. assert_global_step (deprecated)

@deprecated(None, "Please switch to tf.train.assert_global_step")
def assert_global_step(global_step_tensor):
    training_util.assert_global_step(global_step_tensor)

3. assert_or_get_global_step

def assert_or_get_global_step(graph=None, global_step_tensor=None):
  """
  Verifies that the global step tensor is valid, or gets one if None is given.
  If global_step_tensor is not None, it is checked for suitability as a global
  step (via assert_global_step).
  Otherwise, get_global_step is used to find the global step tensor, which is
  then returned.

  Args:
    graph: The graph to find the global step tensor for.
    global_step_tensor: The tensor to check for suitability as a global step.
      If None is given (the default), find a global step tensor.
  Returns:
    A tensor suitable as a global step, or `None` if none was provided and none
    was found.
  """
  if global_step_tensor is None:
      # Get the global step tensor the same way the supervisor would.
      global_step_tensor = get_global_step(graph)
  else:
      assert_global_step(global_step_tensor)
  return global_step_tensor

4. get_global_step (deprecated)

@deprecated(None, "Please switch to tf.train.get_global_step")
def get_global_step(graph=None):
  return training_util.get_global_step(graph)

5. create_global_step (deprecated)

@deprecated(None, "Please switch to tf.train.create_global_step")
def create_global_step(graph=None):
  """Create global step tensor in graph.
  This API is deprecated. Use core framework training version instead.
  Args:
    graph: The graph in which to create the global step tensor. If missing,
      use default graph.
  Returns:
    Global step tensor.
  Raises:
    ValueError: if global step tensor is already defined.
  """
  return training_util.create_global_step(graph)

6. get_or_create_global_step (deprecated)

@deprecated(None, "Please switch to tf.train.get_or_create_global_step")
def get_or_create_global_step(graph=None):
  """Returns and create (if necessary) the global step tensor.
  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph.
  Returns:
    The global step tensor.
  """
  return training_util.get_or_create_global_step(graph)
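
Since the global-step helpers above are mostly deprecated wrappers, new code should call the tf.train equivalents directly. A short sketch (TF 1.x):

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()   # creates the tensor on first call
assert tf.train.get_global_step() is global_step     # later lookups find the same tensor
tf.train.assert_global_step(global_step)             # raises TypeError if not a valid step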

7. local_variable

def local_variable(initial_value, validate_shape=True,
                   name=None, use_resource=None):
  """
  Creates a variable with the given value and adds it to `GraphKeys.LOCAL_VARIABLES`.

  Args:
    initial_value: See variables.Variable.__init__.
    validate_shape: See variables.Variable.__init__.
    name: See variables.Variable.__init__.
    use_resource: If `True` use a ResourceVariable instead of a Variable.
  Returns:
    New variable.
  """
  return variable_scope.variable(
      initial_value, trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES],
      validate_shape=validate_shape,
      use_resource=use_resource,
      name=name)

8. global_variable

def global_variable(initial_value, validate_shape=True,
                    name=None, use_resource=None):
  """
  Creates a variable with the given value and adds it to `GraphKeys.GLOBAL_VARIABLES`.

  Args:
    initial_value: See variables.Variable.__init__.
    validate_shape: See variables.Variable.__init__.
    name: See variables.Variable.__init__.
    use_resource: If `True` use a ResourceVariable instead of a Variable.
  Returns:
    New variable.
  """
  return variable_scope.variable(
      initial_value, trainable=False,
      collections=[ops.GraphKeys.GLOBAL_VARIABLES],
      validate_shape=validate_shape,
      use_resource=use_resource,
      name=name)
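
local_variable and global_variable differ only in the target collection; both create non-trainable variables. A quick sketch (assuming the functions are reachable via tf.contrib.framework):

import tensorflow as tf

counter = tf.contrib.framework.local_variable(0, name='counter')
step = tf.contrib.framework.global_variable(0, name='step')
assert counter in tf.local_variables()
assert step in tf.global_variables()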

9. variable

@contrib_add_arg_scope
def variable(name, shape=None, dtype=None, initializer=None,
             regularizer=None, trainable=True, collections=None,
             caching_device=None, device=None,
             partitioner=None, custom_getter=None, use_resource=None):
  """
  Gets an existing variable with these parameters, or creates a new one.

  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable.
    dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: initializer for the variable if one is created.
    regularizer: a (Tensor -> Tensor or None) function; the result of applying
        it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True`, also add the variable to the graph collection
        GraphKeys.TRAINABLE_VARIABLES (see `tf.Variable`).
    collections: A list of collection names to which the Variable will be
        added. If None, the variable is added to `tf.GraphKeys.GLOBAL_VARIABLES`
        by default.
    caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device.
    device: Optional device to place the variable. It can be a string or a
        function that is called to get the device for the variable.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and dtype of the `Variable` to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    custom_getter: Callable that allows overwriting the internal
      get_variable method and has to have the same signature.
    use_resource: If `True` use a ResourceVariable instead of a Variable.
  Returns:
    The created or existing variable.
  """
  collections = list(collections if collections is not None
                     else [ops.GraphKeys.GLOBAL_VARIABLES])

  # Remove duplicates
  collections = list(set(collections))
  getter = variable_scope.get_variable
  if custom_getter is not None:
    getter = functools.partial(custom_getter,
                               reuse=variable_scope.get_variable_scope().reuse)
  with ops.device(device or ''):
    return getter(name, shape=shape, dtype=dtype,
                  initializer=initializer,
                  regularizer=regularizer,
                  trainable=trainable,
                  collections=collections,
                  caching_device=caching_device,
                  partitioner=partitioner,
                  use_resource=use_resource)
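
A usage sketch, adapted from the TF-Slim documentation: unlike tf.get_variable, slim's variable takes a device argument directly, so no surrounding tf.device block is needed:

import tensorflow as tf
slim = tf.contrib.slim

weights = slim.variable('weights',
                        shape=[10, 10, 3, 3],
                        initializer=tf.truncated_normal_initializer(stddev=0.1),
                        regularizer=slim.l2_regularizer(0.05),
                        device='/CPU:0')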

10. model_variable

Similar to the variable function above, but specifically for model variables.

@contrib_add_arg_scope
def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
                   regularizer=None, trainable=True, collections=None,
                   caching_device=None, device=None, partitioner=None,
                   custom_getter=None, use_resource=None):
  """Gets an existing model variable with these parameters or creates a new one.
  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable.
    dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: initializer for the variable if one is created.
    regularizer: a (Tensor -> Tensor or None) function; the result of
        applying it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: A list of collection names to which the Variable will be added.
      Note that the variable is always also added to the
      `GraphKeys.GLOBAL_VARIABLES` and `GraphKeys.MODEL_VARIABLES` collections.
    caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.
    device: Optional device to place the variable. It can be an string or a
      function that is called to get the device for the variable.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and dtype of the `Variable` to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    custom_getter: Callable that allows overwriting the internal
      get_variable method and has to have the same signature.
    use_resource: If `True` use a ResourceVariable instead of a Variable.
  Returns:
    The created or existing variable.
  """
  collections = list(collections or [])
  collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES]
  var = variable(name, shape=shape, dtype=dtype,
                 initializer=initializer, regularizer=regularizer,
                 trainable=trainable, collections=collections,
                 caching_device=caching_device, device=device,
                 partitioner=partitioner, custom_getter=custom_getter,
                 use_resource=use_resource)
  return var
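
Model variables are the parameters of a model (weights, biases), as opposed to bookkeeping variables such as the global step; the extra MODEL_VARIABLES collection lets them be retrieved as a group. A sketch:

import tensorflow as tf
slim = tf.contrib.slim

weights = slim.model_variable('weights',
                              shape=[10, 10, 3, 3],
                              initializer=tf.truncated_normal_initializer(stddev=0.1))
assert weights in slim.get_model_variables()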

11. add_model_variable

def add_model_variable(var):
  """
  Adds a variable to the `GraphKeys.MODEL_VARIABLES` collection.

  Args:
    var: a variable.
  """
  if var not in ops.get_collection(ops.GraphKeys.MODEL_VARIABLES):
    ops.add_to_collection(ops.GraphKeys.MODEL_VARIABLES, var)
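
This is useful when a variable is created outside of slim (e.g. with a raw tf.Variable) but should still be tracked as a model variable, e.g.:

import tensorflow as tf
slim = tf.contrib.slim

my_var = tf.Variable(tf.zeros([20]), name='my_var')   # created without slim
slim.add_model_variable(my_var)
assert my_var in slim.get_model_variables()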

12. get_variables

def get_variables(scope=None, suffix=None,
                  collection=ops.GraphKeys.GLOBAL_VARIABLES):
  """
  Gets the list of variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return; can be a
        variable scope or a string.
    suffix: an optional suffix for filtering the variables to return.
    collection: the collection in which to search for variables; defaults to
        `GraphKeys.GLOBAL_VARIABLES`.
  Returns:
    A list of variables in the collection with the given scope and suffix.
  """
  if isinstance(scope, variable_scope.VariableScope):
    scope = scope.name
  if suffix is not None:
    if ':' not in suffix:
      suffix += ':'
    scope = (scope or '') + '.*' + suffix
  return ops.get_collection(collection, scope)
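
Note how suffix is compiled into the regex (scope).*suffix: before the collection lookup. A filtering sketch (the scope and variable names are illustrative):

import tensorflow as tf
slim = tf.contrib.slim

with tf.variable_scope('conv1'):
    w = slim.model_variable('weights', shape=[3, 3, 16])
    b = slim.model_variable('biases', shape=[16])

print(slim.get_variables(scope='conv1'))                    # [conv1/weights, conv1/biases]
print(slim.get_variables(scope='conv1', suffix='weights'))  # [conv1/weights]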

13. get_model_variables

def get_model_variables(scope=None, suffix=None):
  """
  Gets the list of model variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.
  Returns:
    A list of variables in the `MODEL_VARIABLES` collection with the given
    scope and suffix.
  """
  return get_variables(scope, suffix, ops.GraphKeys.MODEL_VARIABLES)

14. get_local_variables

def get_local_variables(scope=None, suffix=None):
  """
  Gets the list of local variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.
  Returns:
    A list of variables in the `LOCAL_VARIABLES` collection with the given
    scope and suffix.
  """
  return get_variables(scope, suffix, ops.GraphKeys.LOCAL_VARIABLES)

15. get_trainable_variables

def get_trainable_variables(scope=None, suffix=None):
  """
  Gets the list of trainable variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.
  Returns:
    A list of variables in the `TRAINABLE_VARIABLES` collection with the given
    scope and suffix.
  """
  return get_variables(scope, suffix, ops.GraphKeys.TRAINABLE_VARIABLES)

16. get_variables_to_restore

def get_variables_to_restore(include=None, exclude=None):
  """
  Gets the list of variables to restore.

  Args:
    include: an optional list/tuple of scope strings for filtering which
        variables from the VARIABLES collection to include. If None, all
        variables are included.
    exclude: an optional list/tuple of scope strings for filtering which
        variables from the VARIABLES collection to exclude. If None, no
        variables are excluded.
  Returns:
    A list of variables to restore.
  Raises:
    TypeError: include or exclude is provided but is not a list or a tuple.
  """
  if include is None:
    # Include all variables.
    vars_to_include = get_variables()
  else:
    if not isinstance(include, (list, tuple)):
      raise TypeError('include is provided but is not a list or a tuple.')
    vars_to_include = []
    for scope in include:
      vars_to_include += get_variables(scope)
  vars_to_exclude = set()
  if exclude is not None:
    if not isinstance(exclude, (list, tuple)):
      raise TypeError('exclude is provided but is not a list or a tuple.')
    for scope in exclude:
      vars_to_exclude |= set(get_variables(scope))
  # Exclude the variables in vars_to_exclude
  return [v for v in vars_to_include if v not in vars_to_exclude]
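
The typical use is fine-tuning: restore everything from a pre-trained checkpoint except the layers being re-trained. A sketch (the 'fc8' scope and checkpoint path are hypothetical):

import tensorflow as tf
slim = tf.contrib.slim

# ... build the model first ...
variables_to_restore = slim.get_variables_to_restore(exclude=['fc8'])
restorer = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
    restorer.restore(sess, '/tmp/model.ckpt')   # hypothetical checkpoint path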

17. get_variables_by_suffix

def get_variables_by_suffix(suffix, scope=None):
  """
  Gets the list of variables that end with the given suffix.

  Args:
    suffix: suffix for filtering the variables to return.
    scope: an optional scope for filtering the variables to return.
  Returns:
    A copied list of variables with the given suffix.
  """
  return get_variables(scope=scope, suffix=suffix)

18. get_variables_by_name

def get_variables_by_name(given_name, scope=None):
  """
  Gets the list of variables that were given that name.

  Args:
    given_name: name given to the variable, without any scope.
    scope: an optional scope for filtering the variables to return.
  Returns:
    A copied list of variables with the given name and scope.
  """
  suffix = '/' + given_name + ':|^' + given_name + ':'
  return get_variables(scope=scope, suffix=suffix)

19. get_unique_variable

def get_unique_variable(var_op_name):
  """
  Gets the variable uniquely identified by var_op_name.

  Args:
    var_op_name: the full name of the variable op, including the scope.
  Returns:
    a tensorflow variable.
  Raises:
    ValueError: if no variable uniquely identified by the name exists.
  """
  candidates = get_variables(scope=var_op_name)
  if not candidates:
    raise ValueError('Couldn\'t find variable %s' % var_op_name)

  for candidate in candidates:
    if candidate.op.name == var_op_name:
      return candidate
  raise ValueError('Variable %s does not uniquely identify a variable' %
                   var_op_name)
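
Because get_variables(scope=...) only filters by prefix, the second pass above checks for an exact op-name match. Usage sketch:

import tensorflow as tf
slim = tf.contrib.slim

with tf.variable_scope('layer1'):
    w = slim.variable('weights', shape=[10])

assert slim.get_unique_variable('layer1/weights') is w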

20. assign_from_values

def assign_from_values(var_names_to_values):
  """
  Creates an assignment operation from a given mapping.
  This function provides a mechanism for performing assignment of variables
  to values in a way that does not fill the graph with large assignment values.

  Args:
    var_names_to_values: A map from variable names to values.
  Returns:
    assign_op: an `Operation` that assigns each of the given variables to the
        requested values.
    feed_dict: The feed dictionary to use when evaluating `assign_op`.
  Raises:
    ValueError: if any of the given variable names were not found.
  """
  feed_dict = {}
  assign_ops = []

  for var_name in var_names_to_values:
    var_value = var_names_to_values[var_name]
    var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, var_name)
    if not var:
      raise ValueError('Variable %s wasn\'t found' % var_name)
    elif len(var) > 1:
      # tf.get_collection is just a filter on the prefix: find the exact match:
      found = False
      for v in var:
        if v.op.name == var_name:
          var = v
          found = True
          break

      if not found:
        raise ValueError('Variable %s doesn\'t uniquely identify a variable' %
                         var_name)
    else:
      var = var[0]

    # TODO(nsilberman): ensure placeholder and assign are on the same device.
    # Assign a placeholder to the value that will be filled later.
    placeholder_name = 'placeholder/' + var.op.name
    placeholder_value = array_ops.placeholder(
        dtype=var.dtype.base_dtype,
        shape=var.get_shape(),
        name=placeholder_name)
    assign_ops.append(var.assign(placeholder_value))

    feed_dict[placeholder_value] = var_value.reshape(var.get_shape())

  assign_op = control_flow_ops.group(*assign_ops)
  return assign_op, feed_dict

21. assign_from_values_fn

def assign_from_values_fn(var_names_to_values):
  """
  Returns a function that assigns specific variables from the given values.

  This function provides a mechanism for performing assignment of variables
  to values in a way that does not fill the graph with large assignment values.

  Args:
    var_names_to_values: A map from variable names to values.
  Returns:
    A function that takes a single argument, a `tf.Session`, that applies the
    assignment operation.
  Raises:
    ValueError: if any of the given variable names were not found.
  """
  assign_op, feed_dict = assign_from_values(var_names_to_values)
  def callback(session):
    return session.run(assign_op, feed_dict)
  return callback
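
A sketch covering both assign_from_values and this wrapper: the values are fed through placeholders at run time, so large constants never become part of the graph:

import numpy as np
import tensorflow as tf
slim = tf.contrib.slim

w = slim.variable('w', shape=[2, 2])
init_fn = slim.assign_from_values_fn({'w': np.ones((2, 2), dtype=np.float32)})
with tf.Session() as sess:
    init_fn(sess)        # runs assign_op with the generated feed_dict
    print(sess.run(w))   # [[1. 1.] [1. 1.]]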

22. get_variable_full_name

# pylint: disable=protected-access
# Currently variable_scope doesn't provide very good APIs to access
# all variables under scope and retrieve and check existing scopes.
def get_variable_full_name(var):
  """
  Returns the full name of a variable.

  For normal Variables, this is the same as var.op.name. For sliced or
  PartitionedVariables, all slices/partitions share the same full name. In
  both cases, the name can be used directly in checkpoint files.

  Args:
    var: A `Variable` object.
  Returns:
    A string that is the full name.
  """
  if var._save_slice_info:
    return var._save_slice_info.full_name
  else:
    return var.op.name

23. assign_from_checkpoint

# TODO(nsilberman): add flag to load exponential moving averages instead
#
# TODO(sguada): Update docs in slim/g3doc/index.md to describe
# the new feature where the var_list dictionary can have values that
# are each a list of Variables.
def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False):
  """
  Creates an operation to assign specific variables from a checkpoint.

  Args:
    model_path: The full path to the model checkpoint. To get the latest
        checkpoint use `model_path = tf.train.latest_checkpoint(checkpoint_dir)`.
    var_list: A list of (possibly partitioned) `Variable` objects or a
        dictionary mapping names in the checkpoint to the corresponding
        variables or list of variables to initialize from that checkpoint
        value. For partitioned Variables, the name in the checkpoint must be
        the full variable, not the name of the partitioned variable, eg.
        "my_var" rather than "my_var/part_4". If empty, returns no_op(), {}.
    ignore_missing_vars: Boolean; if True, variables missing from the
        checkpoint are ignored with a warning instead of raising an error.
  Returns:
    the restore_op and the feed_dict that need to be run to restore var_list.
  Raises:
    ValueError: If `ignore_missing_vars` is False and the checkpoint specified
        at `model_path` is missing one of the variables in `var_list`.
  """
  # Normalize var_list into a dictionary mapping names in the
  # checkpoint to the list of variables to initialize from that
  # checkpoint variable. Sliced (including partitioned) variables will
  # end up under the same key.
  grouped_vars = {}
  if isinstance(var_list, (tuple, list)):
    for var in var_list:
      ckpt_name = get_variable_full_name(var)
      if ckpt_name not in grouped_vars:
        grouped_vars[ckpt_name] = []
      grouped_vars[ckpt_name].append(var)

  else:
    for ckpt_name, value in var_list.items():
      if isinstance(value, (tuple, list)):
        grouped_vars[ckpt_name] = value
      else:
        grouped_vars[ckpt_name] = [value]

  # Read each checkpoint element. Create a placeholder variable and
  # add the (possibly sliced) data from the checkpoint to the feed_dict.
  reader = pywrap_tensorflow.NewCheckpointReader(model_path)
  feed_dict = {}
  assign_ops = []
  for ckpt_name in grouped_vars:
    if not reader.has_tensor(ckpt_name):
      log_str = 'Checkpoint is missing variable [%s]' % ckpt_name
      if ignore_missing_vars:
        logging.warning(log_str)
        continue
      else:
        raise ValueError(log_str)
    ckpt_value = reader.get_tensor(ckpt_name)

    for var in grouped_vars[ckpt_name]:
      placeholder_tensor = array_ops.placeholder(
          dtype=var.dtype.base_dtype,
          shape=var.get_shape(),
          name='placeholder/' + var.op.name)
      assign_ops.append(var.assign(placeholder_tensor))

      if not var._save_slice_info:
        if var.get_shape() != ckpt_value.shape:
          raise ValueError(
              'Total size of new array must be unchanged for %s '
              'lh_shape: [%s], rh_shape: [%s]'
              % (ckpt_name, str(ckpt_value.shape), str(var.get_shape())))

        feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape)
      else:
        slice_dims = zip(var._save_slice_info.var_offset,
                         var._save_slice_info.var_shape)
        slice_dims = [(start, start + size) for (start, size) in slice_dims]
        slice_dims = [slice(*x) for x in slice_dims]
        slice_value = ckpt_value[slice_dims]
        slice_value = slice_value.reshape(var._save_slice_info.var_shape)
        feed_dict[placeholder_tensor] = slice_value

  assign_op = control_flow_ops.group(*assign_ops)
  return assign_op, feed_dict
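
A sketch of the op/feed_dict form, e.g. for use as an init_op in slim.learning.train (checkpoint path hypothetical):

import tensorflow as tf
slim = tf.contrib.slim

# ... build the model first ...
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
    '/tmp/model.ckpt', slim.get_variables_to_restore())   # hypothetical path
with tf.Session() as sess:
    sess.run(init_assign_op, feed_dict=init_feed_dict)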
# pylint: enable=protected-access

24. assign_from_checkpoint_fn

def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False,
                              reshape_variables=False):
  """
  Returns a function that assigns specific variables from a checkpoint.
  If ignore_missing_vars is True and no variables are found in the checkpoint,
  it returns None.

  Args:
    model_path: The full path to the model checkpoint. To get the latest
        checkpoint use `model_path = tf.train.latest_checkpoint(checkpoint_dir)`.
    var_list: A list of `Variable` objects or a dictionary mapping names in the
        checkpoint to the corresponding variables to initialize. If empty or
        `None`, it would return `no_op(), None`.
    ignore_missing_vars: Boolean; if True, variables missing from the
        checkpoint are ignored with a warning instead of raising an error.
    reshape_variables: Boolean; if True, variables whose shape differs from the
        one stored in the checkpoint but which have the same number of elements
        are automatically reshaped.
  Returns:
    A function that takes a single argument, a `tf.Session`, that applies the
    assignment operation. If no matching variables were found in the checkpoint
    then `None` is returned.
  Raises:
    ValueError: If var_list is empty.
  """
  if not var_list:
    raise ValueError('var_list cannot be empty')
  if ignore_missing_vars:
    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
    if isinstance(var_list, dict):
      var_dict = var_list
    else:
      var_dict = {var.op.name: var for var in var_list}
    available_vars = {}
    for var in var_dict:
      if reader.has_tensor(var):
        available_vars[var] = var_dict[var]
      else:
        logging.warning(
            'Variable %s missing in checkpoint %s', var, model_path)
    var_list = available_vars
  if var_list:
    saver = tf_saver.Saver(var_list, reshape=reshape_variables,
                           write_version=saver_pb2.SaverDef.V1)
    def callback(session):
      saver.restore(session, model_path)
    return callback
  else:
    logging.warning('No Variables to restore')
    return None
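
Unlike assign_from_checkpoint, this variant delegates to a tf.train.Saver and reads from the checkpoint directly rather than through placeholders. A sketch (path hypothetical):

import tensorflow as tf
slim = tf.contrib.slim

init_fn = slim.assign_from_checkpoint_fn(
    '/tmp/model.ckpt',                   # hypothetical path
    slim.get_variables_to_restore(),
    ignore_missing_vars=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if init_fn is not None:              # None if nothing matched
        init_fn(sess)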

25. VariableDeviceChooser

class VariableDeviceChooser(object):
  """
  Device chooser for variables.
  When using a parameter server, variables are assigned to parameter-server
  tasks in a round-robin fashion.
  When not using a parameter server, it allows GPU or CPU placement.
  """

  def __init__(self, num_tasks=0, job_name='ps', device_type='CPU',
               device_index=0, replica=None):
    """
    Initializes VariableDeviceChooser.

    Usage:
      To use with 2 parameter servers:  VariableDeviceChooser(2)
      To use without parameter servers: VariableDeviceChooser()
                                        VariableDeviceChooser(device_type='GPU')  # for GPU placement
    Args:
      num_tasks: number of tasks.
      job_name: String, a name for the parameter server job.
      device_type: Optional device type string (e.g. "CPU" or "GPU")
      device_index: int.  Optional device index.  If left unspecified, device represents 'any' device_index.
    """
    self._job_name = job_name
    self._device_type = device_type
    self._device_index = device_index
    self._replica = replica
    self._num_tasks = num_tasks
    self._next_task_id = 0

  def __call__(self, op):
    device_spec = tf_device.DeviceSpec(replica=self._replica,
                                       device_type=self._device_type,
                                       device_index=self._device_index)
    if self._num_tasks > 0:
      task_id = self._next_task_id
      self._next_task_id = (self._next_task_id + 1) % self._num_tasks
      device_spec.job = self._job_name
      device_spec.task = task_id
    return device_spec.to_string()
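
Since __call__ maps an op to a device string, an instance can be handed to slim's variable creators through their device argument (e.g. via arg_scope). A round-robin sketch over two parameter-server tasks:

import tensorflow as tf
slim = tf.contrib.slim

chooser = slim.VariableDeviceChooser(num_tasks=2)
with slim.arg_scope([slim.variable, slim.model_variable], device=chooser):
    a = slim.variable('a', shape=[1])   # -> /job:ps/task:0/device:CPU:0
    b = slim.variable('b', shape=[1])   # -> /job:ps/task:1/device:CPU:0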

26. filter_variables

def filter_variables(var_list, include_patterns=None, exclude_patterns=None,
                     reg_search=True):
  """
  採用正則表示式過濾變數列表.
  首先,根據 include_patterns 列表,來列入(include)變數.
  然後,根據 exclude_patterns 列表,來排除(exclude)變數.

  例如,可以採用如下方式獲取所有卷積層的權重變數列表(取決於網路定義):
  variables = tf.contrib.framework.get_model_variables()
  conv_weight_variables = tf.contrib.framework.filter_variables(
      variables,
      include_patterns=['Conv'],
      exclude_patterns=['biases', 'Logits'])


  Args:
    var_list: 變數列表.
    include_patterns: include 的正則表示式列表.
                                預設為 None, 表示根據 include  規則選擇所有的變數.
                                如果變數與 include _patterns 中任何一個相匹配,則列入該變數.
    exclude_patterns: excluede 的正則表示式列表.
                                預設為 None, 表示根據 exclude 規則選擇所有的變數.
                                如果變數與 exclude _patterns 中任何一個相匹配,則排除該變數.
    reg_search: boolean. 預設為 True, 採用 re.search 操作,查詢匹配項
                                (i.e. pattern 可以匹配變數名字的任何子字串). 
                                如果為 False,採用 re.match 操作. 
                                (i.e. regexp 從變數名的首部進行匹配).
   Returns:
     過濾的變數列表.
  """
  if reg_search:
    reg_exp_func = re.search
  else:
    reg_exp_func = re.match

  # First include variables.
  if include_patterns is None:
    included_variables = list(var_list)
  else:
    included_variables = []
    for var in var_list:
      if any(reg_exp_func(ptrn, var.name) for ptrn in include_patterns):
        included_variables.append(var)

  # Afterwards, exclude variables.
  if exclude_patterns is None:
    filtered_variables = included_variables
  else:
    filtered_variables = []
    for var in included_variables:
      if not any(reg_exp_func(ptrn, var.name) for ptrn in exclude_patterns):
        filtered_variables.append(var)

  return filtered_variables