1. 程式人生 > >numpy陣列中reshape和squeeze函式的使用

numpy陣列中reshape和squeeze函式的使用

參考了:http://blog.csdn.net/zenghaitao0128/article/details/78512715,作了一些自己的補充。

numpy中的reshape函式和squeeze函式是深度學習程式碼編寫中經常使用的函式,需要深入的理解。

其中,reshape函式用於調整陣列的軸和維度,而squeeze函式的用法如下,

語法:numpy.squeeze(a,axis = None)

 1)a表示輸入的陣列;  2)axis用於指定需要刪除的維度,但是指定的維度必須為單維度,否則將會報錯;  3)axis的取值可為None 或 int 或 tuple of ints, 可選。若axis為空,則刪除所有單維度的條目;  4)返回值:陣列  5) 不會修改原陣列; 作用:從陣列的形狀中刪除單維度條目,即把shape中為1的維度去掉  

舉例:

numpy的reshape和squeeze函式:

import numpy as np
e = np.arange(10)
print(e)
一維陣列:[0 1 2 3 4 5 6 7 8 9]
f = e.reshape(1,1,10)
print(f)

三維陣列:(第三個方括號裡有十個元素)

[[[0 1 2 3 4 5 6 7 8 9]]],前兩維的秩為1

g = f.reshape(1,10,1)
print(g)

三維陣列:(第二個方括號裡有十個元素)

[[[0]
  [1]
  [2]
  [3]
  [4]
  [5]
  [6]
  [7]
  [8]
  [9]]]
h = e.reshape(10,1,1)
print(h)
三維陣列:(第一個方括號裡有10個元素)
[[[0]]

 [[1]]

 [[2]]

 [[3]]

 [[4]]

 [[5]]

 [[6]]

 [[7]]

 [[8]]

 [[9]]]

利用squeeze可以把陣列中的1維度去掉(從0開始指定軸),以下為不加引數axis,去掉所有1維的軸:

m = np.squeeze(h)
print(m)

以下指定去掉第幾軸

n = np.squeeze(h,2)
print(n)
去掉第三軸,變成二維陣列,維度為(10,1):
[[0]
 [1]
 [2]
 [3]
 [4]
 [5]
 [6]
 [7]
 [8]
 [9]]

再舉一個例子:

p = np.squeeze(g,2)
print(p)

去掉第2軸,得到二維陣列,維度為(1,10):

[[0 1 2 3 4 5 6 7 8 9]]
p = np.squeeze(g,0)
print(p)

去掉第0軸,得到二維陣列,維度為(10,1):

[[0]
 [1]
 [2]
 [3]
 [4]
 [5]
 [6]
 [7]
 [8]
 [9]]

在matplotlib畫圖中,非單維的陣列在plot時候會出現問題,(1,nx)不行,但是(nx, )可以,(nx,1)也可以。

如下:

import matplotlib.pyplot as plt
squares =np.array([[1,4,9,16,25]]) 
print(squares.shape)    

square的維度為(1,5),無法畫圖:

做如下修改:

plt.plot(np.squeeze(squares))    
plt.show()

square的維度為(5,),可以畫圖:

或者做如下修改

squares1 = squares.reshape(5,1)
plt.plot(squares1)  
plt.show()

square的維度為(5,1),可以畫圖: