1. 程式人生 > >tf.shape() 和x.get_shape().as_list() 和tf.split()

tf.shape() 和x.get_shape().as_list() 和tf.split()

1、tf.shape(A) # 獲取張量A(陣列,list, tensor張量)的大小,返回的是一個list。

import tensorflow as tf
import numpy as np

a_array=np.array([[1,2,3],[4,5,6]])
b_list=[[1,2,3],[3,4,5]]
c_tensor=tf.constant([[1,2,3],[4,5,6]])

with tf.Session() as sess:
    print(sess.run(tf.shape(a_array)))
    print(sess.run(tf.shape(b_list)))
    print(sess.run(tf.shape(c_tensor)))

返回[2, 3],[2, 3],[2, 3]。

2、x.get_shape().as_list()

x.get_shape(),只有tensor才可以使用這種方法,返回的是一個元組

import tensorflow as tf
import numpy as np

a_array=np.array([[1,2,3],[4,5,6]])
b_list=[[1,2,3],[3,4,5]]
c_tensor=tf.constant([[1,2,3],[4,5,6]])

print(c_tensor.get_shape())
print(c_tensor.get_shape().as_list())

with tf.Session() as sess:
    print(sess.run(tf.shape(a_array)))
    print(sess.run(tf.shape(b_list)))
    print(sess.run(tf.shape(c_tensor)))

返回:(2, 3),[2, 3],  [2, 3], [2, 3],  [2, 3]

只能用於tensor來返回shape,但是是一個元組,需要通過as_list()的操作轉換成list.

3、tf.split()

tf.split(dimension, num_split, input):dimension的意思就是輸入張量的哪一個維度,如果是0就表示對第0維度進行切割。num_split就是切割的數量,如果是2就表示輸入張量被切成2份,每一份是一個列表。

import tensorflow as tf
import numpy as np

A = [[1,2,3],[4,5,6]]
x = tf.split(1, 3, A)

with tf.Session() as sess:
	c = sess.run(x)
	for ele in c:
		print ele

輸出: 

[[1],
  [4]]
 [[2],
  [5]]
 [[3],
  [6]]