1. 程式人生 > >徹底理解numpy中的axis

徹底理解numpy中的axis

時下流行人工智慧,python成為人工智慧最好的處理語言,這與python中的科學計算模組numpy是分不開的。numpy相信大都數人都知道。而在numpy中,有很多的函式都涉及到axis,numpy中的軸axis是很重要的,許多numpy的操作根據axis的取值不同,作出的操作也不相同。可以說,axis讓numpy的多維陣列變的更加靈活,但也讓numpy變得越發難以理解。因此,弄清楚axis的作用顯得尤為重要。

  • 為什麼要設計axis這個東西呢?

numpy是個多維陣列,多維陣列運算需要指定到底對哪一維操作,因此axis就是用來指定需要操作的維數。 

  •  簡單示例

先來看看一個二維陣列

>>> import numpy as np  
>>> np_data = np.array([[1,2,3],[4,5,6],[1,3,5],[2,4,6]])
>>> np_data
array([[1, 2, 3],                                                                                                                          
       [4, 5, 6],                                                                                                              
       [1, 3, 5],                                                                                                              
       [2, 4, 6]])

假設 這個陣列代表了樣本資料的特徵,其中每一行代表一個樣本的三個特徵,每一列是不同樣本的特徵。如果在分析樣本的過程中需要對每個樣本的三個特徵求和,該如何處理?

>>> np.sum(np_data,axis=1)
array([ 6, 15,  9, 12])

那如果想求每種特徵的最小值,該如何處理?

>>> np.min(np_data, axis=0)
array([1, 2, 3])

又如果想得知所有樣本所有特徵的平均值呢?

>>> np.average(np_data)
3.5
  • 重點

由此可以看出:(劃重點)通過不同的axis,numpy會沿著不同的方向進行操作:如果不設定,那麼對所有的元素操作;如果axis=0,則沿著縱軸進行操作;axis=1,則沿著橫軸進行操作。

但這只是簡單的二位陣列,如果是多維的呢?可以總結為一句話:設axis=i,則numpy沿著第i個下標變化的方向進行操作。例如剛剛的例子,可以將表示為:

data =[[a00, a01],
       [a10, a11]]

所以axis=0時,沿著第0個下標變化的方向進行操作,也就是a00->a10, a01->a11,也就是縱座標的方向,axis=1時也類似。

  • 多維陣列驗證

下面我們舉一個四維的求sum的例子來驗證一下:

>>> np_data = np.random.randint(0, 6, [4,2,3,4])
>>> np_data
array([[[[3, 5, 5, 0],                                                                                                               
         [0, 1, 2, 4],                                                                                                           
         [0, 5, 0, 5]],    # D0000  ->   D0023
                                                                                                                                                                                                                                
         [[5, 5, 0, 0],                                                                                                           
          [2, 1, 5, 0],                                                                                                           
          [1, 0, 0, 1]]],  # D0100  ->   D0123
                                                                                                                                                                                                                                                                                                                                                     
       [[[0, 5, 1, 2],                                                                                                           
         [4, 4, 2, 2],                                                                                                           
         [3, 5, 0, 1]],    # D1000  ->   D1023                                                                                                                                                                                                                        
        
        [[5, 1, 2, 1],                                                                                                           
         [2, 2, 3, 5],                                                                                                           
         [5, 3, 3, 3]]],   # D1100  ->   D1123                                                                                                                                                                                                                                                                                                                                                  

       [[[2, 4, 1, 4],                                                                                                           
         [1, 4, 1, 4],                                                                                                           
         [4, 5, 0, 2]],    # D2000  ->   D2023                                                                                                                                                                                                                         
        
        [[2, 5, 5, 1],                                                                                                           
         [5, 3, 0, 2],                                                                                                           
         [4, 0, 1, 3]]],   # D2100  ->   D2123                                                                                                                                                                                                                                                                                                                                                  

       [[[1, 3, 4, 5],                                                                                                           
         [0, 2, 5, 4],                                                                                                           
         [2, 3, 5, 3]],    # D3000  ->   D3023                                                                                                                                                                                                                           
       
        [[2, 2, 2, 2],                                                                                                           
         [3, 2, 1, 3],                                                                                                           
         [0, 3, 0, 1]]]])  # D3100  ->   D3123
  • 當axis=0時,numpy驗證第0維的方向來求和,也就是第一個元素值=D0000+D1000+D2000+D3000=3+0+2+1=6,第二個元素=D0001+D1001+D2001+D3001=5+5+4+3=17,同理可得最後的結果如下:
>>> np_data.sum(axis=0)
array([[[ 6, 17, 11, 11],                                                                                                       
        [ 5, 11, 10, 14],                                                                                                       
        [ 9, 18,  5, 11]],                                                                                                                                                                                                                             
      
      [[14, 13,  9,  4],                                                                                                       
       [12,  8,  9, 10],                                                                                                       
       [10,  6,  4,  8]]])
  • 當axis=3時,numpy驗證第3維的方向來求和,也就是第一個元素值=D0000+D0001+D0002+D0003=3+5+5+0=13,第二個元素=D0010+D0011+D0012+D0013=0+1+2+4=7,同理可得最後的結果如下:
>>> data.sum(axis=3)
array([[[13,  7, 10],                                                                                                               
        [10,  8,  2]],                                                                                                                                                                                                                                 
    
       [[ 8, 12,  9],                                                                                                           
        [ 9, 12, 14]],                                                                                                                                                                                                                                 

       [[11, 10, 11],                                                                                                           
        [13, 10,  8]],                                                                                                                                                                                                                                 

       [[13, 11, 13],                                                                                                           
        [ 8,  9,  4]]])
  • 使用axis的相關函式

在numpy中,使用的axis的地方非常多,處理上文已經提到的average、max、min、sum,比較常見的還有sort和prod,下面分別舉幾個例子看一下:

  • sort
>>> np_data = np.random.randint(0, 4, [2,2,3])
>>> np_data
array([[[2, 0, 2],                                                                                                              
        [1, 1, 1]],                                                                                                                                                                                                                                    

       [[3, 2, 1],                                                                                                              
        [2, 2, 0]]])
>>> np.sort(np_data)  ## 預設對最大的axis進行排序,這裡即是axis=2
array([[[0, 2, 2],                                                                                                              
        [1, 1, 1]],                                                                                                                                                                                                                                    

       [[1, 2, 3],                                                                                                              
        [0, 2, 2]]])
>>> np.sort(np_data, axis=0) #沿著第0維進行排序,原先的D000->D100
array([[[2, 0, 1],                                                                                                              
        [1, 1, 0]],                                                                                                                                                                                                                                    

       [[3, 2, 2],                                                                                                              
        [2, 2, 1]]])
>>> np.sort(np_data, axis=1)  # 沿著第1維進行排序
array([[[1, 0, 1],                                                                                                              
        [2, 1, 2]],                                                                                                                                                                                                                                    

       [[2, 2, 0],                                                                                                              
        [3, 2, 1]]])
>>> np.sort(np_data, axis=2)  # 沿著第2維進行排序(和預設對最大的axis進行排序一樣結果)
array([[[0, 2, 2],                                                                                                              
        [1, 1, 1]],                                                                                                                                                                                                                                    

       [[1, 2, 3],                                                                                                              
        [0, 2, 2]]]) 
>>> np.sort(np_data, axis=None)  # 對全部資料進行排序
array([0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3])
  • prod(即product,乘積)
 >>> np.prod([[1.,2.],[3.,4.]])
 24.0

 >>> np.prod([[1.,2.],[3.,4.]], axis=1)
 array([  2.,  12.])

 >>> np.prod([[1.,2.],[3.,4.]], axis=0)
 array([ 3.,  8.])

簡要概括axis的用法是:axis = i,則numpy沿著第 i 個下標變化的方向進行操作

相信通過上面的講解與例子,你應該對axis有了比較清楚的瞭解。只有真正理解axis的原理,才能對numpy操作遊刃有餘。