1. 程式人生 > >TensorFlow中張量轉置操作tf.transpose用法詳解

TensorFlow中張量轉置操作tf.transpose用法詳解

一、環境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

Python 3.6.3

二、官方說明

對張量按照指定的排列維度進行轉置

tf.transpose(
    a,
    perm=None,
    name='transpose',
    conjugate=False
)

輸入:

(1)a:輸入張量

(2)perm:輸入張量要進行轉置操作的維度的排列方式

(3)name:可選引數,轉置操作的名稱

(4)conjugate:可選引數,布林型別,如果設定為True,則數學意義上等同於tf.conj(tf.transpose(input))

輸出:

(1)按照指定維度排列方式轉置後的張量

 

三、例項

(1)不設定perm引數值時,perm預設為(n-1, n-2, ..., 2, 1, 0),其中n為輸入張量的階(rank)

>>> x = tf.constant([[1,2,3],[4,5,6]])
>>> with tf.Session() as sess:
...     print(sess.run(tf.transpose(x)))
...     print(sess.run(tf.shape(tf.transpose(x))))
... 
[[1 4]
 [2 5]
 [3 6]]
... 
[3 2]

(2)上例中等同例項(張量x的階為2,因此perm預設為(2-1,0))

>>> x = tf.constant([[1,2,3],[4,5,6]])
>>> with tf.Session() as sess:
...     print(sess.run(tf.transpose(x,[1,0])))
...     print(sess.run(tf.shape(tf.transpose(x))))
... 
[[1 4]
 [2 5]
 [3 6]]
... 
[3 2]

(3)輸入張量為複數的情況,引數conjugate=True時,進行共軛轉置操作

>>> real = [[1.0,2.0,3.0],[4.0,5.0,6.0]]
>>> imag = [[1.0,2.0,3.0],[4.0,5.0,6.0]]
>>> complex = tf.complex(real,imag)
>>> with tf.Session() as sess:
...     print(sess.run(complex))
...     print(sess.run(tf.shape(complex)))
...     print(sess.run(tf.transpose(complex)))
...     print(sess.run(tf.shape(tf.transpose(complex))))
...     print(sess.run(tf.transpose(complex,conjugate=True)))
...     print(sess.run(tf.shape(tf.transpose(complex,conjugate=True))))
... 
[[1.+1.j 2.+2.j 3.+3.j]
 [4.+4.j 5.+5.j 6.+6.j]]
... 
[2 3]
... 
[[1.+1.j 4.+4.j]
 [2.+2.j 5.+5.j]
 [3.+3.j 6.+6.j]]
... 
[3 2]
... 
[[1.-1.j 4.-4.j]
 [2.-2.j 5.-5.j]
 [3.-3.j 6.-6.j]]
... 
[3 2]
... 

(4)輸入張量的維度大於2時,引數perm起作用更大

直觀來講,這裡的引數perm=[0,2,1],控制將原來的維度[0,1,2]後面兩列置換位置

>>> x = tf.constant([[[ 1,  2,  3],
...                   [ 4,  5,  6]],
...                  [[ 7,  8,  9],
...                   [10, 11, 12]]])
>>> with tf.Session() as sess:
...     print(sess.run(tf.transpose(x,[0,2,1])))
... 
[[[ 1  4]
  [ 2  5]
  [ 3  6]]

 [[ 7 10]
  [ 8 11]
  [ 9 12]]]