
TensorFlow API: tf.map_fn

tf.map_fn(
    fn,
    elems,
    dtype=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None
)

Purpose: unpacks elems along dimension 0 into a list of tensors, applies fn to each of them, and stacks the results.
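The "unpack along dimension 0, apply fn, restack" behaviour can be modelled in plain Python. This is a minimal sketch only. It assumes a single "tensor" represented as a nested list. `map_fn_sketch` is a hypothetical helper and is not part of TensorFlow.

```python
# Hypothetical helper, NOT part of TensorFlow: a plain-Python model of
# tf.map_fn for a single "tensor" represented as a nested list.
def map_fn_sketch(fn, elems):
    # Unpack along dimension 0: each item of the outer list is one slice.
    slices = list(elems)
    # Apply fn to every slice, then restack the results along a new dimension 0.
    return [fn(s) for s in slices]


matrix = [[1, 2], [3, 4], [5, 6]]  # shape (3, 2)
print(map_fn_sketch(sum, matrix))  # [3, 7, 11]
```

Each row of the (3, 2) matrix is one slice, so fn (here `sum`) runs three times. The result has length 3 along its first dimension.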

Parameters
fn: The callable to be performed. It accepts one argument, which will have the same (possibly nested) structure as elems. Its output must have the same structure as dtype if one is provided; otherwise it must have the same structure as elems.

elems: A tensor or (possibly nested) sequence of tensors, each of which will be unpacked along its first dimension. The nested sequence of the resulting slices will be passed to fn.
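When elems is a tuple of tensors, each component is unpacked separately. fn then receives one tuple per index, built from the i-th slice of every component. The sketch below models one level of such nesting with Python lists. It assumes the components must agree in size along dimension 0, because otherwise the slices cannot be paired up. `map_fn_tuple_sketch` is a hypothetical helper, not a TensorFlow function.

```python
# Hypothetical helper, NOT part of TensorFlow. Handles one level of nesting:
# elems is either a list or a tuple of lists.
def map_fn_tuple_sketch(fn, elems):
    if isinstance(elems, tuple):
        # Every component is unpacked along its own first dimension,
        # so the components must agree on that size.
        sizes = {len(component) for component in elems}
        if len(sizes) != 1:
            raise ValueError("components of elems differ in size along dimension 0")
        # Slice i is the tuple (elems[0][i], elems[1][i], ...),
        # mirroring the tuple structure of elems.
        slices = list(zip(*elems))
    else:
        slices = list(elems)
    return [fn(s) for s in slices]


print(map_fn_tuple_sketch(lambda x: x[0] * x[1], ([1, 2, 3], [-1, 1, -1])))  # [-1, 2, -3]
```

This mirrors the second official example further down: fn indexes into the tuple it receives (`x[0]`, `x[1]`).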

dtype: (optional) The output type(s) of fn. If fn returns a structure of Tensors differing from the structure of elems, then dtype is not optional and must have the same structure as the output of fn.
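Consider fn returning a pair for each slice of a single list. The output structure (a pair of stacked results) then differs from the structure of elems (one list), which is the case where tf.map_fn needs dtype, e.g. `dtype=(tf.int64, tf.int64)`. The plain-Python sketch below shows only the restacking. It sidesteps the dtype question by inspecting the first result. `map_fn_multi_output_sketch` is hypothetical and not TensorFlow's implementation.

```python
# Hypothetical helper, NOT part of TensorFlow: restacks a tuple-valued fn
# output into a tuple of lists, one list per output component.
def map_fn_multi_output_sketch(fn, elems):
    outputs = [fn(x) for x in elems]
    if outputs and isinstance(outputs[0], tuple):
        # [(a0, b0), (a1, b1), ...] -> ([a0, a1, ...], [b0, b1, ...])
        return tuple(list(component) for component in zip(*outputs))
    return outputs


pos, neg = map_fn_multi_output_sketch(lambda x: (x, -x), [1, 2, 3])
print(pos, neg)  # [1, 2, 3] [-1, -2, -3]
```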

parallel_iterations: (optional) The number of iterations allowed to run in parallel.
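As a loose analogy only, parallel_iterations plays a role similar to a worker limit: at most that many applications of fn may be in flight at once, yet the results are still stacked in slice order. The sketch below expresses this idea with Python's `concurrent.futures` thread pool. This is an assumption made for illustration. TensorFlow does not run fn on Python threads, and `map_fn_parallel_sketch` is a hypothetical helper.

```python
from concurrent.futures import ThreadPoolExecutor


# Hypothetical helper, NOT part of TensorFlow: a thread-pool analogy for
# parallel_iterations. At most `parallel_iterations` calls of fn run at once.
def map_fn_parallel_sketch(fn, elems, parallel_iterations=10):
    with ThreadPoolExecutor(max_workers=parallel_iterations) as pool:
        # Executor.map yields results in input order, so the output is still
        # stacked in slice order no matter which call finishes first.
        return list(pool.map(fn, elems))


print(map_fn_parallel_sketch(lambda x: x * x, [1, 2, 3, 4, 5, 6], parallel_iterations=2))
# [1, 4, 9, 16, 25, 36]
```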

back_prop: (optional) True enables support for back propagation.

swap_memory: (optional) True enables GPU-CPU memory swapping.

infer_shape: (optional) False disables tests for consistent output shapes.
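The kind of check that infer_shape=True enables can be sketched by requiring every per-slice output to have the same shape before the outputs are stacked. This is a minimal model under two assumptions: fn returns flat Python lists, and "shape" reduces to list length. `map_fn_shape_checked_sketch` is hypothetical, not TensorFlow code.

```python
# Hypothetical helper, NOT part of TensorFlow: models the consistency check
# that infer_shape=True enables, for fn outputs represented as flat lists.
def map_fn_shape_checked_sketch(fn, elems, infer_shape=True):
    outputs = [fn(x) for x in elems]
    if infer_shape:
        # All per-slice outputs must share one shape (here: one length).
        lengths = {len(o) for o in outputs}
        if len(lengths) > 1:
            raise ValueError(f"inconsistent output shapes: lengths {sorted(lengths)}")
    return outputs


print(map_fn_shape_checked_sketch(lambda n: [n] * 2, [1, 2, 3]))  # [[1, 1], [2, 2], [3, 3]]
```

With infer_shape=False, the sketch returns ragged lists unchecked. A dense tensor cannot hold rows of different lengths, so that branch has no direct TensorFlow counterpart.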

name: (optional) Name prefix for the returned tensors.

Examples from the official documentation:

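import numpy as np
import tensorflow as tf

# Square every element.
elems = np.array([1, 2, 3, 4, 5, 6])
squares = tf.map_fn(lambda x: x * x, elems)
# squares == [1, 4, 9, 16, 25, 36]

# fn receives a tuple (x[0], x[1]) holding one slice from each array.
# Explicit int64 arrays match dtype=tf.int64 on every platform.
elems = (np.array([1, 2, 3], dtype=np.int64), np.array([-1, 1, -1], dtype=np.int64))
alternate = tf.map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
# alternate == [-1, 2, -3]

# fn returns a tuple, so dtype must describe both outputs.
elems = np.array([1, 2, 3], dtype=np.int64)
alternates = tf.map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
# alternates[0] == [1, 2, 3]
# alternates[1] == [-1, -2, -3]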