1. 程式人生 > >numpy和tensorflow中的關於引數axis的正確理解

numpy和tensorflow中的關於引數axis的正確理解

當給axis賦值為0時,和採取預設值時的表現是完全不同的,從下面的程式碼就可以看出。

>>> z #大小為2×3×4的陣列
array([[[ 2,  3,  4,  8],
        [ 3,  1,  4,  1],
        [ 6,  3,  2,  6]],

       [[10,  2, 45,  2],
        [ 2,  4,  5, 10],
        [22,  4,  4,  1]]])
>>> np.sum(z,axis=0)  # axis=0
array([[12,  5, 49, 10],
       [ 5,  5,  9, 11],
       [28,  7,  6,  7]]
) >>> np.sum(z) #axis不指定,取預設值 154

剛開始學習numpy和tensorflow的朋友經常遇到類似下面這樣的一些函式:

#python
x=[[1,2],[5,1]]
x=np.array(x)
z1=np.max(x,axis=0)
z2=np.max(x,axis=1)


#tensorflow
x=tf.constant([[1.,2.],[5.,2.]])  
x=tf.shape(x)  
z1=tf.reduce_max(x,axis=0)#沿axis=0操作  
z2=tf.reduce_max(x,axis=1)#沿axis=

類似的還有argmax,sum等等函式,它們都含有一個名為axis的引數,那這個引數到底是什麼意思呢?一句話總結就是:沿著axis指定的軸進行相應的函式操作

直接看這句話可能看不懂,下面用一個最簡單的例子來說明一下。

import numpy as np
#首先,建立一個2×3維的numpy的array陣列
x=[[2,3,4],[1,2,5]]
x=np.array(x)
#然後,計算不同引數下np.max的輸出

print(np.max(x))
# 5
print(np.max(x,0))
# [2,3,5]
print(np.max(x,1))
# [4,5]

可以看到,如果不知道axis,那麼預設就是取得整個陣列的最大值,這相當於把多維陣列展開成一維,然後找到這個一維數組裡的最大值。 而當axis=0時,直觀上來看就是取得每一列的最大值,源陣列總共為2行3列,所以最終的輸出包含3個元素。 當axis=1時,就相當與是取每一行的最大值。

上面的理解方式在二維陣列還比較直觀,但是如果陣列達到3維4維甚至更高維時,就不能簡單的從行列角度出發去理解了,這時應該考慮從“軸”的角度來看。首先,明確一點,“軸”是從外向裡的,也就是說,最外層的是0軸,往內一次是1軸,2軸… 。 具體可以看下面的例子:

>>> z
array([[[ 2,  3,  4,  8],
        [ 3,  1,  4,  1],
        [ 6,  3,  2,  6]],

       [[10,  2, 45,  2],
        [ 2,  4,  5, 10],
        [22,  4,  4,  1]]])
>>> z.shape
(2, 3, 4)

可以看到,這是一個2×3×4的三位陣列,其中0軸對應第一維(2),1軸對應第二維(3),2軸對應第三維(4)。當我們指定了函式按某一軸來計算時,函式的輸出陣列的shape就是去掉當前軸的shape,如下所示。

>>> np.max(z,axis=0).shape
(3, 4)
>>> np.max(z,axis=1).shape
(2, 4)
>>> np.max(z,axis=2).shape
(2, 3)

而對於輸出陣列的每一個元素output[i][j]的值,實際上就是z[i][...][j]集合中的最大值,如下面的程式碼所示。其中當axis=0時,輸出陣列output的shape為3×4,其中output.[2][3]的值,實際上就是z[0][2][3],z[1][2][3]的最大值,也就是(6,1)中的最大值,即為output.[2][3]=6

再如axis=1時,輸出陣列output的shape為2×4,其中output.[1][2]的值,實際上就是z[1][0][2],z[1][1][2],z[1][2][2]中的最大值,也就是(45,5,4)中的最大值,即為output.[1][2]=45]

>>> np.max(z,axis=0)
array([[10,  3, 45,  8],
       [ 3,  4,  5, 10],
       [22,  4,  4,  6]])
>>> np.max(z,axis=1)
array([[ 6,  3,  4,  8],
       [22,  4, 45, 10]])
>>> np.max(z,axis=2)
array([[ 8,  4,  6],
       [45, 10, 22]])

用形式化的數學語言總結上面的過程就是: 對於大小為[i,j,k]的輸入陣列z,假設axis=0,那麼輸出矩陣output的大小就為[j,k],並且output的每一個元素的計算方式如下:

output[j,k]=maxi(z[i,j,k])output[j,k]=maxi(z[i,j,k])

如果axis=1,那麼輸出矩陣output的大小就為[i,k],並且output的每一個元素的計算方式如下:

output[i,k]=maxj(z[i,j,k])output[i,k]=maxj(z[i,j,k])

對於4維,5維甚至無限維的情況,計算方法是一樣的,你不妨自己推導一下,如果有任何問題,歡迎可以在評論中留言。

另外,對於其他的sum,argmax等等函式中的計算方法也是一樣的,只需要把函式max換成對應的函式即可,如下所示:

sum:

output[j,k]=i(z[i,j,k])output[j,k]=∑i(z[i,j,k])