1. 程式人生 > >TensorFlow中張量連線操作tf.concat用法詳解

TensorFlow中張量連線操作tf.concat用法詳解

一、環境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

Python 3.6.3

二、官方說明

按指定軸(axis)進行張量連線操作(Concatenates Tensors)

tf.concat(
    values,
    axis,
    name='concat'
)

輸入:

(1)values:多個張量組成的列表或者單個張量

(2)axis:0維整形張量(整數),定義按照按個數據軸進行張量連線操作,其範圍是[-輸入張量的階,+輸入張量的階]。[0,輸入張量的階]範圍內的正數表示按照指定的axis軸進行連線操作,在[-輸入張量的階,0]之間的負數表示按照指定的(axis + 輸入張量的階)的軸進行連線操作

(3)name:可選引數,定義該張量連線操作的名稱

輸出:

輸入張量按照指定軸連線後的一個結果張量

 

三、例項

(1)單個張量作為輸入

>>> t1 = [[1,2,3],[4,5,6]]
>>> con1 = tf.concat(t1,0)
>>> shape1 = tf.shape(con1)
>>> with tf.Session() as sess:
...     print(sess.run(con1))
...     print(sess.run(shape1))
... 
[1 2 3 4 5 6]
[6]

(2)多個張量組成的列表作為輸入

按照0軸(行)進行連線:

>>> t1 = [[1,2,3],[4,5,6]]
>>> t2 = [[7,8,9],[10,11,12]]
>>> con2 = tf.concat([t1,t2],0)
>>> shape2 = tf.shape(con2)
>>> with tf.Session() as sess:
...     print(sess.run(con2))
...     print(sess.run(shape2))
... 
[[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]]
[4 3]

按照1軸(列)進行連線:

>>> t1 = [[1,2,3],[4,5,6]]
>>> t2 = [[7,8,9],[10,11,12]]
>>> con3 = tf.concat([t1,t2],1)
>>> shape3 = tf.shape(con3)
>>> with tf.Session() as sess:
...     print(sess.run(con3))
...     print(sess.run(shape3))
... 
[[ 1  2  3  7  8  9]
 [ 4  5  6 10 11 12]]
[2 6]
>>> 

按照-1軸(列)進行連線:

>>> t1 = [[1,2,3],[4,5,6]]
>>> t2 = [[7,8,9],[10,11,12]]
>>> con4 = tf.concat([t1,t2],-1)
>>> shape4 = tf.shape(con4)
>>> with tf.Session() as sess:
...     print(sess.run(con4))
...     print(sess.run(shape4))
... 
[[ 1  2  3  7  8  9]
 [ 4  5  6 10 11 12]]
[2 6]

 

注意:如果想沿著一個新軸連線張量,則考慮使用stcak

不建議使用:tf.concat([tf.expand_dims(t, axis) for t in tensors],axis)

推薦使用:tf.stack(tensors, axis=axis)