Python中np.sum()對axis的個人理解,超詳細
你們討論的axis=0和1並不是簡單的行和列,axis=0表示的是第一個維度,在第一個維度上的元素間進行求和、比較大小,axis=1表示的是第二個維度,在第二個維度上的元素間進行求和、比較大小。一個維度的元素並不總是單值,有時候是一個數組或矩陣等等,這時候就要在對應位置上再進行求和、比較大小等等。
如果你不想知道其原理,簡單的記作:axis=0表示對行(上下)進行操作計算,axis=1(左右)表示對列進行操作計算,也是沒有問題的。
首先看一下什麼叫做維度,一個矩陣的維度大家都知道是二維。包含行和列。
接著看下面這個陣列a,這個陣列可以看作是一個二維陣列,每一維中又同時包含一個矩陣。運用下你的空間想象能力。所以這個陣列是三維的。
>>> import numpy as np
>>> a = np.array([[[1,2,3,2],[1,2,3,1],[2,3,4,1]],[[1,0,2,0],[2,1,2,0],[2,1,1,1]]])
>>> a
array([[[1, 2, 3, 2],
[1, 2, 3, 1],
[2, 3, 4, 1]],
[[1, 0, 2, 0],
[2, 1, 2, 0],
[2, 1, 1, 1]]])
由下面我們可以看到a的維度是3,其實有個簡便的看法,你看小括號旁有幾個中括號 [ 就是幾維,簡單吧。
>>> a.ndim
3
>>> a.shape
(2, 3, 4)
當我們想要定位到某個元素時,需要表示為 ,容易理解,三維所以需要三層括號。其中 i 表示第一維,j 表示第二維,k 表示第三維。
axis取多少,就表明在哪個維度上求和。
- axis=None 表示對所有元素求和。
- axis=0 表示在第1個維度上求和。
- axis=1 表示在第2個維度上求和。
- 以此類推…
二維陣列
讓我們從簡單的開始,先來看二維陣列的求和,假設有一個2x2的矩陣b,也就是一個二維陣列。第一個維度包含兩個陣列,每個陣列包含兩個值。
>>> b = np.array([[1, 2], [2, 3]])
>>> b
array([[1, 2],
[2, 3]])
>>> b.shape
(2, 2)
當axis=0
時,表示在第1個維度上求和,也就是在第一個維度上的元素間的求和。在這裡也就是行與行之間進行求和。因為第一個維度中包含兩個陣列,這兩個陣列也就是矩陣的兩行,發揮你的想象力。接著我們來算一下,行與行之間進行求和,首先是第一行的第一個元素與第二行的第一個元素求和(1+2)=3,然後是(2+3)=5。有下面可以看到計算正確。且可以看到輸出的shape是(1,2),也就是一行兩列。注意這一點,後面會繼續說明。直觀理解是對列進行求和。
>>> np.sum(b,axis=0)
array([3, 5])
>>> np.sum(b,axis=0).shape
(2,)
如果我們用公式來表達呢:
好理解嗎?當你習慣後,你會發現用公式會比直覺好理解的多,且更容易泛化。當資料是高維的時候,你的直覺可能就不夠用了,這時用公式你會發現會非常簡單。 表示第一個維度,axis=0表示在第一個維度上進行求和,所以公式表達就是對 進行求和, 為結果的索引。所以最後的輸出的shape是刪掉了第一個維度的shape,這裡刪掉了2,結果的shape是(1,2),也就是一行兩列。好好理解公式。
axis=1
時,在第二個維度上進行求和,是第二個維度的元素間的求和。在這裡也就是列之間的求和,首先是(1+2)=3,然後(2+3)=5。
>>> np.sum(b,axis=1)
array([3, 5])
公式,對 進行求和,不變:
三維陣列
當我們熟悉公式後,計算高維資料就不用直覺來推了,直接套公式。
axis=0
,對 進行求和,其餘為結果的索引。求和公式為:
輸出的shape自然也就是去掉 的,為(j, k),這裡就是(3,4)
例如s[1,1]=(1+1)=2,結果正確,其餘的可以自己驗算。
>>> np.sum(a,axis=0)
array([[2, 2, 5, 2],
[3, 3, 5, 1],
[4, 4, 5, 2]])
>>> np.sum(a,axis=0).shape
(3, 4)
axis=1
,對 進行求和,直接把上面的公式中的 改為 就行了。求和公式為:
還要注意輸出的shape等於原來的減去的,也就是(2,4)
>>> np.sum(a,axis=1)
array([[ 4, 7, 10, 4],
[ 5, 2, 5, 1]])
>>> np.sum(a,axis=1).shape
(2, 4)
axis=2
就不寫了,與上面同理。
axis為元組
還有一個問題,上面討論的是axis取值為單值的情況,當axis等於一個元組的時候呢?例如axis=(0,1)
其實和上面是一樣的道理,axis等於什麼,就在公式中對什麼求和就行了,例如axis=(0,1),也就是在第一個維度和第二個維度上進行求和,公式中就是對 , , 進行求和,其餘的則是結果的索引。
公式:
輸出的shape為原來的shape減去 i 和 j 的shape就行
>>> np.sum(a,axis=(0, 1))
array([ 9, 9, 15, 5])
>>> np.sum(a,axis=(0, 1)).shape
(4,)
對三維陣列a來說,axis=(0,1,2)也就等於axis=None,對所有元素進行求和。
那麼對於高位陣列,同樣也只要套公式就行了。
剛開始看可能有點繞,有任何問題可以留言交流。