1. 程式人生 > >numpy學習筆記-將條件邏輯表述為陣列運算

numpy學習筆記-將條件邏輯表述為陣列運算

numpy.where函式是三元表示式x if condition else y的向量化版本。假設我們有一個布林陣列和兩個值陣列。

xarr = np.array([1.1,1.2,1.3,1.4,1.5])
yarr = np.array([2.1,2.2,2.3,2.4,2.5])
cond = np.array([True,False,True,True,False])

假設我們想要根據cond中的值選取xarr和yarr的值:當cond中的值為true時,選取xarr的值,否則從yarr中選取。列表推導式的寫法應該如下所示:

result = [(x if c else y) for
x,y,c in zip(xarr,yarr,cond)] print result
輸出結果
[1.1000000000000001, 2.2000000000000002, 1.3, 1.3999999999999999, 2.5]

但是這樣有幾個問題。
一、它對大陣列的處理速度不是很快
二、它無法用於多維陣列。
若使用np.where,則可以將該功能寫的特別簡潔:

result = np.where(cond,xarr,yarr)
print result
輸出結果
[ 1.1  2.2  1.3  1.4  2.5]

np.where第二個和第三個引數不必是陣列,也可以是標量。

如下為隨機資料組成的矩陣,將所有正值替換為2,所有負值替換為-2

arr = np.random.randn(4,4)
print arr
print np.where(arr>0,2,-2)
輸出結果
[[ 1.80262171  0.7143772   2.28177789  0.43296688]
 [-0.71345291  0.70720256  0.09209334 -1.12625402]
 [ 1.45065299  0.55110371 -0.659351   -0.41175648]
 [ 0.4328763   0.67499992 -0.53531592 -0.38604227]]
[[ 2  2  2  2]
 [-2  2  2 -2]
 [ 2  2 -2 -2]
 [ 2  2 -2 -2]]

只將正值設定為2

print np.where(arr>0,2,arr)
輸出結果
[[-1.05460459  2.          2.          2.        ]
 [ 2.          2.          2.          2.        ]
 [-0.52191645 -0.98692719  2.          2.        ]
 [-2.11448246  2.         -0.25533101 -1.07167209]]

還可以用where實現更復雜的邏輯
比如

np.where(cond1&cond2,0,
            np.where(cond1,1,
                        np.where(cond2,2,3)))

等價於

result = []
for i in range(n):
    if cond1[i] and cond2[i]:
        result.append(0)
    elif cond1[i]:
        result.append(1)
    elif cond2[i]:
        result.append(2)
    else:
        result.append(3)