1. 程式人生 > >TensorFlow中張量轉置操作tf.expand_dims用法

TensorFlow中張量轉置操作tf.expand_dims用法

一、環境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

cudnn64_7.dll

Python 3.6.3

Windows 10

二、官方說明

給輸入張量的形狀增加1個維度

https://www.tensorflow.org/api_docs/python/tf/expand_dims

tf.expand_dims(
    input,
    axis=None,
    name=None,
    dim=None
)

輸入:

(1)input:輸入張量

(2)axis:標量,指定在哪個維度上給輸入張量增加一個維度,範圍必須在 [-輸入張量的秩,+輸入張量的秩]

(3)name:輸出結果張量的名稱

(4)dim:標量,等同於axis,被棄用

 

返回結果:

(1)比輸入張量多1維但是包含相同資料的張量

 

三、例項

(1)tf.expand_dims(input_tensor,0)

>>> import tensorflow as tf
>>> input = tf.constant([1,2], shape=[2])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2])
>>> sess.run(tf.shape(tf.expand_dims(input,0)))
array([1, 2])
>>> sess.run(tf.expand_dims(input,0))
array([[1, 2]])
>>> sess.close()

(2)tf.expand_dims(input_tensor,1)

>>> import tensorflow as tf
input = tf.constant([1,2], shape=[2])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([1])
>>> sess.run(tf.shape(tf.expand_dims(input,1)))
array([2, 1])
>>> sess.run(tf.expand_dims(input,1))
array([[1],
       [2]])

(3)tf.expand_dims(input_tensor,-1)

>>> import tensorflow as tf
>>> input = tf.constant([1,2], shape=[2])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([1])
>>> sess.run(tf.shape(tf.expand_dims(input,-1)))
array([2, 1])
>>> sess.run(tf.expand_dims(input,-1))
array([[1],
       [2]])

(4)多維拓展0

>>> import tensorflow as tf
>>> input = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], shape=[2,3,5])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2, 3, 5])
>>> sess.run(tf.shape(tf.expand_dims(input,0)))
array([1, 2, 3, 5])
>>> sess.run(tf.expand_dims(input,0))
array([[[[ 1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10],
         [11, 12, 13, 14, 15]],

        [[16, 17, 18, 19, 20],
         [21, 22, 23, 24, 25],
         [26, 27, 28, 29, 30]]]])
>>> sess.close()

(5)多維拓展2

>>> import tensorflow as tf
>>> input = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], shape=[2,3,5])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2, 3, 5])
>>> sess.run(tf.shape(tf.expand_dims(input,2)))
array([2, 3, 1, 5])
>>> sess.run(tf.expand_dims(input,2))
array([[[[ 1,  2,  3,  4,  5]],

        [[ 6,  7,  8,  9, 10]],

        [[11, 12, 13, 14, 15]]],


       [[[16, 17, 18, 19, 20]],

        [[21, 22, 23, 24, 25]],

        [[26, 27, 28, 29, 30]]]])

(6)多維拓展3

>>> import tensorflow as tf
>>> input = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], shape=[2,3,5])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2, 3, 5])
>>> sess.run(tf.shape(tf.expand_dims(input,3)))
array([2, 3, 5, 1])
>>> sess.run(tf.expand_dims(input,3))
array([[[[ 1],
         [ 2],
         [ 3],
         [ 4],
         [ 5]],

        [[ 6],
         [ 7],
         [ 8],
         [ 9],
         [10]],

        [[11],
         [12],
         [13],
         [14],
         [15]]],


       [[[16],
         [17],
         [18],
         [19],
         [20]],

        [[21],
         [22],
         [23],
         [24],
         [25]],

        [[26],
         [27],
         [28],
         [29],
         [30]]]])