1. 程式人生 > >陣列拼接tf.concat()和np.concatenate()的區別

陣列拼接tf.concat()和np.concatenate()的區別

陣列拼接tf.concat()和np.concatenate()的區別


Tensorflow新手,在程式裡要將兩個陣列進行拼接,使用了tf.concat()函式,然而執行過程中出現瞭如下錯誤:

TypeError: Tensors in list passed to ‘values’ of ‘ConcatV2’ Op have types [float64, < NOT CONVERTIBLE TO TENSOR>] that don’t all match.

具體程式碼如下:

All_Patches1 = scipy.io.loadmat(path + '/data/All_Patches1.mat')['All_Patches1']
All_Patches2 = scipy.io.loadmat(path + '/data/All_Patches2.mat')['All_Patches2']
All_Labels = scipy.io.loadmat(path + '/data/All_Labels.mat')['All_Labels']

All_Patches = tf.concat([All_Patches1,All_Patches2],axis = 0)

在查了一系列相關類似錯誤後始終沒有把自己bug解決,只得向師兄尋求援助,師兄只瞅了一眼就發現了我幾個小時咩有解決的問題,原來癥結就在對tensorflow計算圖上理解不夠。
tensorflow所有的計算都是計算圖上的一個節點,定義好的運算是通過會話(session)來執行的,使用tf.concat()函式後也需要在session中執行,但是我只是對兩個常量陣列進行拼接,此時正確的函式應該是np.concatenate(),它可以直接使用,即如下程式碼:

All_Patches = np.concatenate((All_Patches1,All_Patches2),axis = 0)

這個問題過後稍稍瞭解了為什麼tensorflow和numpy中均定義了類似功能的函式。正確使用tf.concat()函式的一個示範如下:

x_labeled = tf.placeholder(tf.float32, [batch_size_labeled, patch_size, patch_size, num_band])
x_unlabeled = tf.placeholder(tf.float32, [batch_size_unlabeled, patch_size, patch_size, num_band])
x_input = tf.concat([x_labeled, x_unlabeled], axis=0)