1. 程式人生 > >不懂錘爆我係列之Tensorflow入門學習—— 張量拓展函式tile()詳解

不懂錘爆我係列之Tensorflow入門學習—— 張量拓展函式tile()詳解

第二期,第二期,開始,開始。

在tensorflow中有個很常用的張量擴充套件函式——tile(),看過了許多講解部落格之後,覺得有必要系統的進行一下整理。同時,我將講解一維、二維、乃至多維張量使用tile()的運算過程與規則。

下面,我們還是以一段程式碼為例:

import tensorflow as tf
#下面兩行是為了拋掉錯誤資訊的提示,大家忽略即可
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

#一維張量, 即shape = [X, ]的張量
a = tf.tile([1, 2, 3],[3])

#二維張量, 即shape = [X, Y]的張量
b = tf.tile( [[1, 2], [2, 3], [3, 4]] ,[3, 2])

#三維張量, 即shape = [X, Y, Z]的張量
temp = tf.Variable(tf.random_normal(shape=(1, 3, 2)))

c = tf.tile(temp, [2, 1, 1])
d = tf.tile(temp, [2, 2, 1])
e = tf.tile(temp, [2, 2, 2])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))
    print(sess.run(b))
    print("---------------------------")
    print(sess.run(temp))
    print("---------------------------")
    print(sess.run(c))
    print(sess.run(d))
    print(sess.run(e))

上文的程式碼,大家可以根據註釋按順序看,下面我的講解也將按照a到e的順序進行。

  • 首先我們要明白tf.tile( , )函式中的兩個引數,第一個引數就是傳入的張量,第二個引數則是擴充套件的倍數

       明白了函式內容後我們看到 a = tf.tile([1, 2, 3], [3]) , a作為一維張量,後面的3賦予a的意義為對a進行三倍的擴充套件。由於tile()的擴充套件遵循不變維數的原則,所以這裡的擴充套件結果要將一維張量的所有數看作一個整體,對整體進行倍增

即:[1 2 3 1 2 3 1 2 3]

  • 接下來是二維張量b = [[1, 2], [2, 3], [3, 4]] ,要擴充套件的倍數為[3, 2]。看到這裡很多朋友可能就有點嘀咕:這個 [3, 2]的擴充套件規則是個什麼意思呢?

      其實很簡單,在這裡我們要用到shape的概念,如果大家還不清楚的話就先去看看嘍,很容易理解。 我們先看擴充套件規則中的第一個3, 在這裡乘以3的具體操作流程為:開啟b最外層的一對 [ ],這時對應的shape為3(這個值和操作流程無關)。開啟之後我們很直觀的看到剩下的是[1, 2], [2, 3], [3, 4]這三個並列的項。我們就將每一項翻3倍得到——>[1, 2],[1, 2],[1, 2], [2, 3],[2, 3],[2, 3],[3, 4], [3, 4], [3, 4]。

這時再看倍數中的第二個數2,這個的意思是,在我們第一步操作結束的基礎上,對我們得到的這9個一維並列項。

應用在上面我對一維張量a進行的擴充套件操作的講解得到——>

[[1 2 1 2], [2 3 2 3],[3 4 3 4],[1 2 1 2], [2 3 2 3], [3 4 3 4], [1 2 1 2], [2 3 2 3], [3 4 3 4]]

再看下面的temp,其實就是一個由隨機陣列成的shape為(1, 3, 2)的三維張量。 在這裡,我們應用遞迴的思想解釋問題(其實就是偷個懶,大家理解了一維二維的話,我只需要講下多出來那一維,剩下的同理啦~)

假設我們得到的三維張量如下: 

[[[-0.9044834  -0.9477105 ]   [-0.48063868 -0.12450311]   [-0.09301946 -0.5487149 ]]]

  • 針對c = tf.tile(temp, [2, 1, 1]),我們可以看到倍數引數中為對應temp維度的[2, 1, 1]。

       三維張量,一共有三層的shape值(如程式碼中的[1, 3, 2]),shape的這三個值與倍數引數中的三個值一 一對應

在這裡按照拆最外層中括號的思想,2這個值其實就對應把temp拆除最外層【】-->得到:

[[-0.9044834  -0.9477105 ] ,[-0.48063868 -0.12450311] ,[-0.09301946 -0.5487149 ]]

        因為這一層shape為1,所以我們就將上面的整體翻二倍即可。(shape的1對應倍數的2)

同理,再拆一層,shape為三就要將以逗號為分隔符的三項整體分別擴充套件1倍 (shape的3對應倍數的1)

           對下一層而言,注意:如果每個並列項只含有一對 [ ],那麼我們就不再拆除他,這樣便於理解。我們在這裡就對所有的並列項做同上文對a的操作一樣,進行整體倍增即可。

大家按照這個思路自己進行一下計算,看看和結果是否匹配哦~~

[[[-0.9044834  -0.9477105 ]   [-0.48063868 -0.12450311]   [-0.09301946 -0.5487149 ]]

 [[-0.9044834  -0.9477105 ]   [-0.48063868 -0.12450311]   [-0.09301946 -0.5487149 ]]]