1. 程式人生 > >Tensorflow - tf.tile() 學習

Tensorflow - tf.tile() 學習

 API:https://tensorflow.google.cn/api_docs/python/tf/tile?hl=zh-cn


tf.tile()用於張量擴充套件

tf.tile(
    input,
    multiples,
    name=None
)

輸入是一個Tensor

multiples的維度與輸入的維度相一致,並標明在哪一個維度上進行擴充套件,擴充套件的方法就是複製為相同的元素,下面的例子可以說明問題:

import tensorflow as tf

raw = tf.Variable(tf.random_normal(shape=(2 ,2, 2)))
multi1 = tf.tile(raw, multiples=[2, 1, 1])
multi2 = tf.tile(raw, multiples=[1, 2, 1])
multi3 = tf.tile(raw, multiples=[1, 1, 2])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(raw.eval())
    print('-----------------------------')
    a = sess.run(multi1)
    b = sess.run(multi2)
    c = sess.run(multi3)
    print(a)
    print(a.shape)
    print('-----------------------------')
    print(b)
    print(b.shape)
    print('-----------------------------')
    print(c)
    print(c.shape)
    print('-----------------------------')
#原始
[[[ 0.6948325  -0.16302951]
  [-0.60185844  0.3866387 ]]

 [[-0.5528875  -0.06845065]
  [ 0.24240932  0.72961247]]]
-----------------------------
# multiples=[2, 1, 1]
[[[ 0.6948325  -0.16302951]
  [-0.60185844  0.3866387 ]]

 [[-0.5528875  -0.06845065]
  [ 0.24240932  0.72961247]]

 [[ 0.6948325  -0.16302951]
  [-0.60185844  0.3866387 ]]

 [[-0.5528875  -0.06845065]
  [ 0.24240932  0.72961247]]]
(4, 2, 2)
-----------------------------
# multiples=[1, 2, 1]
[[[ 0.6948325  -0.16302951]
  [-0.60185844  0.3866387 ]
  [ 0.6948325  -0.16302951]
  [-0.60185844  0.3866387 ]]

 [[-0.5528875  -0.06845065]
  [ 0.24240932  0.72961247]
  [-0.5528875  -0.06845065]
  [ 0.24240932  0.72961247]]]
(2, 4, 2)
-----------------------------
# multiples=[1, 1, 2]
[[[ 0.6948325  -0.16302951  0.6948325  -0.16302951]
  [-0.60185844  0.3866387  -0.60185844  0.3866387 ]]

 [[-0.5528875  -0.06845065 -0.5528875  -0.06845065]
  [ 0.24240932  0.72961247  0.24240932  0.72961247]]]
(2, 2, 4)
-----------------------------