tf.shape, x.shape, x.get_shape的區別

tf.shape(x)

回傳的型別為tensorflow.python.framework.ops.Tensor,因此可以在計算圖中使用。在使用sess.run或x.eval()後可以得到x的具體形狀(不會有?的存在)。

x.shape, x.get_shape()

可以將這兩種寫法想成是一樣的功能,回傳的型別皆為tensorflow.python.framework.tensor_shape.TensorShape無法在計算圖中使用,意即無法使用sess.run()或x.eval()來得到確切的值。

使用示例

下面展示了這三種用法的不同之處:
x.shape及x.get_shape()可以單獨使用,得到tensor x在定義時的形狀。
tf.shape則可以被放入計算圖中,得到張量在執行後真正的形狀。同時它也可以被放入其他的tf運算內(此處將tf.shape放到tf.reshape內),成為計算圖的一部份。

// An highlighted block
x1 = np.arange(32).reshape(2,16)
x2 = np.arange(32).reshape(4,8)
a = tf.placeholder(shape=(None, 16), dtype=tf.float32)
b = tf.placeholder(shape=(None, 8), dtype=tf.float32)
c = tf.reshape(b, tf.shape(a))
# c = tf.reshape(b, a.shape) #not work
# c = tf.reshape(b, a.get_shape()) #not work
print(a.shape) #(?, 8)
print(a.get_shape()) #(?, 8)
with tf.Session() as sess:
	print(sess.run(tf.shape(a), feed_dict={a: x1})) #[ 2 16]
    print(sess.run(a, feed_dict={a: x1}))
    """
    result:
    [[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15.]
 	[16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.]]
    """
    print(sess.run(b, feed_dict={b: x2}))
    """
    result:
    [[ 0.  1.  2.  3.  4.  5.  6.  7.]
	 [ 8.  9. 10. 11. 12. 13. 14. 15.]
	 [16. 17. 18. 19. 20. 21. 22. 23.]
	 [24. 25. 26. 27. 28. 29. 30. 31.]]
    """
    print(sess.run(c, feed_dict={a: x1, b: x2}))
    """
    result:
    [[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15.]
 	[16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.]]
    """

參考連結

[1]https://stackoverflow.com/questions/37085430/tf-shape-get-wrong-shape-in-tensorflow