1. 程式人生 > >tensorflow 之 tf.shape() 和 tf.get_shape() 及 tf.concat()

tensorflow 之 tf.shape() 和 tf.get_shape() 及 tf.concat()

1、用法:tf.shape( input, out_type, name=None )

  • input:可以是tensor,list,arrray。
  • out_type:可選,tf.int32 和 tf.int64;預設 tf.int32。
  • name:操作的名稱(可選)

返回:型別為 out_type 的 input維度,是一個 tensor。其實就是獲取輸入的維度

import tensorflow as tf

t1 = [[[1,2], [3,4]],
      [[5,6], [7,8]]]
       
t1_shape1 = tf.shape(t1)
print(t1_shape1)
#返回:Tensor("Shape_8:0", shape=(3,), dtype=int32)

t1_shape = t1_shape1.get_shape()
print(t1_shape)
#返回:(3,)

t1_shape = t1_shape1.get_shape().as_list()
print(t1_shape)
#返回:[3]

2、用法:tf.get_shape(x)

返回:x 維度的元組,輸入 x 只能是tensor

通常為了對元組處理,必須將元組轉換成 list。在神經網路中使用它,我猜是保險起見~防止經過一系列處理之後輸入型別改變,使用該函式,輸入只能是tensor。

3、tf.concat([x1, x2], axis=0/1/2/3....)

將 x1 和 x2 的維度拼接,如何拼接取決於 axis,axis=0 表示在 x1 和 x2的第一維度拼接,1 表示在第二維度拼接,我遇到的是在第4維度拼接,應該只是取決於你輸入的維度大小,如果維度是2,想要在3,4維度拼接肯定是報錯的。

 

 

 

 

 

 

 

 

 

 

參考:https://www.w3cschool.cn/tensorflow_python/tensorflow_python-smjo2k45.html

https://blog.csdn.net/fireflychh/article/details/73611021