TensorFlow中張量轉置操作tf.transpose用法詳解
阿新 • • 發佈:2018-12-13
一、環境
TensorFlow API r1.12
CUDA 9.2 V9.2.148
Python 3.6.3
二、官方說明
對張量按照指定的排列維度進行轉置
tf.transpose(
a,
perm=None,
name='transpose',
conjugate=False
)
輸入:
(1)a:輸入張量
(2)perm:輸入張量要進行轉置操作的維度的排列方式
(3)name:可選引數,轉置操作的名稱
(4)conjugate:可選引數,布林型別,如果設定為True,則數學意義上等同於tf.conj(tf.transpose(input))
輸出:
(1)按照指定維度排列方式轉置後的張量
三、例項
(1)不設定perm引數值時,perm預設為(n-1, n-2, ..., 2, 1, 0),其中n為輸入張量的階(rank)
>>> x = tf.constant([[1,2,3],[4,5,6]]) >>> with tf.Session() as sess: ... print(sess.run(tf.transpose(x))) ... print(sess.run(tf.shape(tf.transpose(x)))) ... [[1 4] [2 5] [3 6]] ... [3 2]
(2)上例中等同例項(張量x的階為2,因此perm預設為(2-1,0))
>>> x = tf.constant([[1,2,3],[4,5,6]])
>>> with tf.Session() as sess:
... print(sess.run(tf.transpose(x,[1,0])))
... print(sess.run(tf.shape(tf.transpose(x))))
...
[[1 4]
[2 5]
[3 6]]
...
[3 2]
(3)輸入張量為複數的情況,引數conjugate=True時,進行共軛轉置操作
>>> real = [[1.0,2.0,3.0],[4.0,5.0,6.0]]
>>> imag = [[1.0,2.0,3.0],[4.0,5.0,6.0]]
>>> complex = tf.complex(real,imag)
>>> with tf.Session() as sess:
... print(sess.run(complex))
... print(sess.run(tf.shape(complex)))
... print(sess.run(tf.transpose(complex)))
... print(sess.run(tf.shape(tf.transpose(complex))))
... print(sess.run(tf.transpose(complex,conjugate=True)))
... print(sess.run(tf.shape(tf.transpose(complex,conjugate=True))))
...
[[1.+1.j 2.+2.j 3.+3.j]
[4.+4.j 5.+5.j 6.+6.j]]
...
[2 3]
...
[[1.+1.j 4.+4.j]
[2.+2.j 5.+5.j]
[3.+3.j 6.+6.j]]
...
[3 2]
...
[[1.-1.j 4.-4.j]
[2.-2.j 5.-5.j]
[3.-3.j 6.-6.j]]
...
[3 2]
...
(4)輸入張量的維度大於2時,引數perm起作用更大
直觀來講,這裡的引數perm=[0,2,1],控制將原來的維度[0,1,2]後面兩列置換位置
>>> x = tf.constant([[[ 1, 2, 3],
... [ 4, 5, 6]],
... [[ 7, 8, 9],
... [10, 11, 12]]])
>>> with tf.Session() as sess:
... print(sess.run(tf.transpose(x,[0,2,1])))
...
[[[ 1 4]
[ 2 5]
[ 3 6]]
[[ 7 10]
[ 8 11]
[ 9 12]]]