tensorflow 之 tf.shape() 和 tf.get_shape() 及 tf.concat()
阿新 • • 發佈:2018-12-20
1、用法:tf.shape( input, out_type, name=None )
- input:可以是tensor,list,arrray。
- out_type:可選,tf.int32 和 tf.int64;預設 tf.int32。
- name:操作的名稱(可選)
返回:型別為 out_type 的 input維度,是一個 tensor。其實就是獲取輸入的維度
import tensorflow as tf t1 = [[[1,2], [3,4]], [[5,6], [7,8]]] t1_shape1 = tf.shape(t1) print(t1_shape1) #返回:Tensor("Shape_8:0", shape=(3,), dtype=int32) t1_shape = t1_shape1.get_shape() print(t1_shape) #返回:(3,) t1_shape = t1_shape1.get_shape().as_list() print(t1_shape) #返回:[3]
2、用法:tf.get_shape(x)
返回:x 維度的元組,輸入 x 只能是tensor
通常為了對元組處理,必須將元組轉換成 list。在神經網路中使用它,我猜是保險起見~防止經過一系列處理之後輸入型別改變,使用該函式,輸入只能是tensor。
3、tf.concat([x1, x2], axis=0/1/2/3....)
將 x1 和 x2 的維度拼接,如何拼接取決於 axis,axis=0 表示在 x1 和 x2的第一維度拼接,1 表示在第二維度拼接,我遇到的是在第4維度拼接,應該只是取決於你輸入的維度大小,如果維度是2,想要在3,4維度拼接肯定是報錯的。
參考:https://www.w3cschool.cn/tensorflow_python/tensorflow_python-smjo2k45.html
https://blog.csdn.net/fireflychh/article/details/73611021