1. 程式人生 > >tf.py_func()函式

tf.py_func()函式

tensorflow由於構建的是靜態圖,所以導致在tf.Session().run()之前是沒有實際值的,因此,在網路搭建的時候,是不能對tensor進行判值操作的,即不能插入if…else…之類的程式碼。第二,相較於numpy array,Tensorflow中對tensor的操作介面靈活性並沒有那麼高,使得Tensorflow的靈活性減弱。

在筆者使用Tensorflow的一年中積累的程式設計經驗來看,擴充套件Tensorflow程式的靈活性,有一個重要的手段,就是使用tf.py_func介面。 介面解析

程式碼測試:

def my_func(array1,array2)
: return array1 + array2, array1 - array2 if __name__ =='__main__': array1 = np.array([[1, 2], [3, 4]]) array2 = np.array([[1, 2], [3, 4]]) a1 = tf.placeholder(tf.float32,[2,2],name = 'array1') a2 = tf.placeholder(tf.float32,[2,2],name = 'array2') y1,y2 = tf.py_func(my_func,
[a1,a2],[tf.float32, tf.float32]) with tf.Session() as sess: y1_,y2_ = sess.run([y1,y2],feed_dict={a1:array1,a2:array2}) print(y1_) print('*'*10) print(y2_)

從上面的程式碼我們可以看出,tf.py_func()接收的是tensor,然後將其轉化為numpy array送入我們自定義的my_func函式,最後再將my_func函式輸出的numpy array轉化為tensor返回

如果不用tf.py_func()實現的話,我們還可以這樣直接用array的方式操作:

def my_func(array1,array2):
    return array1 + array2, array1 - array2

with tf.Session() as sess:
    array1 = np.array([[1, 2], [3, 4]])
    array2 = np.array([[1, 2], [3, 4]])
    y1,y2 = my_func(array1,array2)
    print(y1)
    print('*' * 10)
    print(y2)