1. 程式人生 > >Python取多維陣列第n維的前幾位

Python取多維陣列第n維的前幾位

現在我們有一個shape為(7352, 9, 128, 1)的numpy陣列。

想要取出第2維的前三個資料,構成新陣列(7352, 3, 128, 1)

我的思想是:將第2維資料轉置(transpose)到第一維,再用切片(slice)取出前三個資料,再轉置回去:

print("# original", input.shape)
input_transpose = input.transpose((1, 0, 2, 3))
print("# transpose", input_transpose.shape)
input_slice = input_transpose[0:3]
print("# slice", input_slice.shape)
output = input_slice.transpose((1, 0, 2, 3))
print("# output", output.shape)


其實更簡單的做法是:

print("# original", input.shape)
print("# output", input[:, [0, 1, 2]].shape)