1. 程式人生 > >Tensorflow(1)進行多維矩陣的拆分與拼接

Tensorflow(1)進行多維矩陣的拆分與拼接

最近在使用tensorflow進行網路訓練的時候,需要提取出別人訓練好的卷積核的部分層的資料。由於tensorflow中的tensor和python中的list不同,無法直接使用加法進行拼接,後來發現一個函式可以完成tensor的拼接。
函式形式如下:

 tf.concat(concat_dim,values,name='concat')

其中,第一個引數表示需要拼接的多維tensor,並且可以將多個tensor同事拼接,第二個表示按照哪一個維度拼接(從數字0開始)。
例子:建立一個三維的tensor,然後分別取出最後一個維度(注意:tensor支援與python中list相似的切片操作,可以使用這種方式進行拆分),然後在拼接在一起。

import tensorflow as tf

weights=tf.Variable(tf.truncated_normal([2,3,4],dtype=tf.float32,stddev=1e-1),name='weights')

weight1=weights[0:2,0:3,1:2]
weight2=weights[0:2,0:3,2:3]
weight3=weights[0:2,0:3,1:2]
weight4=tf.concat([weight1,weight2,weight3],2) #2表示最後一個維度

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	print(sess.run(weights))
	print("****************")
	print(sess.run(weight4))

在這裡插入圖片描述