1. 程式人生 > >tf.shape, x.shape, x.get_shape的區別

tf.shape, x.shape, x.get_shape的區別

tf.shape, x.shape, x.get_shape的區別

tf.shape(x)

回傳的型別為tensorflow.python.framework.ops.Tensor,因此可以在計算圖中使用。在使用sess.run或x.eval()後可以得到x的具體形狀(不會有?的存在)。

x.shape, x.get_shape()

可以將這兩種寫法想成是一樣的功能,回傳的型別皆為tensorflow.python.framework.tensor_shape.TensorShape

無法在計算圖中使用,意即無法使用sess.run()或x.eval()來得到確切的值。

使用示例

下面展示了這三種用法的不同之處:
x.shape及x.get_shape()可以單獨使用,得到tensor x在定義時的形狀。
tf.shape則可以被放入計算圖中,得到張量在執行後真正的形狀。同時它也可以被放入其他的tf運算內(此處將tf.shape放到tf.reshape內),成為計算圖的一部份。

// An highlighted block
x1 = np.arange(32).reshape(2,16)
x2 = np.arange(32).reshape(4,8)
a = tf.
placeholder(shape=(None, 16), dtype=tf.float32) b = tf.placeholder(shape=(None, 8), dtype=tf.float32) c = tf.reshape(b, tf.shape(a)) # c = tf.reshape(b, a.shape) #not work # c = tf.reshape(b, a.get_shape()) #not work print(a.shape) #(?, 8) print(a.get_shape()) #(?, 8) with tf.Session() as sess: print
(sess.run(tf.shape(a), feed_dict={a: x1})) #[ 2 16] print(sess.run(a, feed_dict={a: x1})) """ result: [[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15.] [16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.]] """ print(sess.run(b, feed_dict={b: x2})) """ result: [[ 0. 1. 2. 3. 4. 5. 6. 7.] [ 8. 9. 10. 11. 12. 13. 14. 15.] [16. 17. 18. 19. 20. 21. 22. 23.] [24. 25. 26. 27. 28. 29. 30. 31.]] """ print(sess.run(c, feed_dict={a: x1, b: x2})) """ result: [[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15.] [16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.]] """

參考連結

[1]https://stackoverflow.com/questions/37085430/tf-shape-get-wrong-shape-in-tensorflow