TensorFlow中張量轉置操作tf.expand_dims用法
阿新 • • 發佈:2018-12-17
一、環境
TensorFlow API r1.12
CUDA 9.2 V9.2.148
cudnn64_7.dll
Python 3.6.3
Windows 10
二、官方說明
給輸入張量的形狀增加1個維度
https://www.tensorflow.org/api_docs/python/tf/expand_dims
tf.expand_dims(
input,
axis=None,
name=None,
dim=None
)
輸入:
(1)input:輸入張量
(2)axis:標量,指定在哪個維度上給輸入張量增加一個維度,範圍必須在 [-輸入張量的秩,+輸入張量的秩]
(3)name:輸出結果張量的名稱
(4)dim:標量,等同於axis,被棄用
返回結果:
(1)比輸入張量多1維但是包含相同資料的張量
三、例項
(1)tf.expand_dims(input_tensor,0)
>>> import tensorflow as tf >>> input = tf.constant([1,2], shape=[2]) >>> sess = tf.Session() >>> sess.run(tf.shape(input)) array([2]) >>> sess.run(tf.shape(tf.expand_dims(input,0))) array([1, 2]) >>> sess.run(tf.expand_dims(input,0)) array([[1, 2]]) >>> sess.close()
(2)tf.expand_dims(input_tensor,1)
>>> import tensorflow as tf input = tf.constant([1,2], shape=[2]) >>> sess = tf.Session() >>> sess.run(tf.shape(input)) array([1]) >>> sess.run(tf.shape(tf.expand_dims(input,1))) array([2, 1]) >>> sess.run(tf.expand_dims(input,1)) array([[1], [2]])
(3)tf.expand_dims(input_tensor,-1)
>>> import tensorflow as tf
>>> input = tf.constant([1,2], shape=[2])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([1])
>>> sess.run(tf.shape(tf.expand_dims(input,-1)))
array([2, 1])
>>> sess.run(tf.expand_dims(input,-1))
array([[1],
[2]])
(4)多維拓展0
>>> import tensorflow as tf
>>> input = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], shape=[2,3,5])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2, 3, 5])
>>> sess.run(tf.shape(tf.expand_dims(input,0)))
array([1, 2, 3, 5])
>>> sess.run(tf.expand_dims(input,0))
array([[[[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10],
[11, 12, 13, 14, 15]],
[[16, 17, 18, 19, 20],
[21, 22, 23, 24, 25],
[26, 27, 28, 29, 30]]]])
>>> sess.close()
(5)多維拓展2
>>> import tensorflow as tf
>>> input = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], shape=[2,3,5])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2, 3, 5])
>>> sess.run(tf.shape(tf.expand_dims(input,2)))
array([2, 3, 1, 5])
>>> sess.run(tf.expand_dims(input,2))
array([[[[ 1, 2, 3, 4, 5]],
[[ 6, 7, 8, 9, 10]],
[[11, 12, 13, 14, 15]]],
[[[16, 17, 18, 19, 20]],
[[21, 22, 23, 24, 25]],
[[26, 27, 28, 29, 30]]]])
(6)多維拓展3
>>> import tensorflow as tf
>>> input = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], shape=[2,3,5])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2, 3, 5])
>>> sess.run(tf.shape(tf.expand_dims(input,3)))
array([2, 3, 5, 1])
>>> sess.run(tf.expand_dims(input,3))
array([[[[ 1],
[ 2],
[ 3],
[ 4],
[ 5]],
[[ 6],
[ 7],
[ 8],
[ 9],
[10]],
[[11],
[12],
[13],
[14],
[15]]],
[[[16],
[17],
[18],
[19],
[20]],
[[21],
[22],
[23],
[24],
[25]],
[[26],
[27],
[28],
[29],
[30]]]])