1. 程式人生 > >pytorch 與 numpy 的陣列廣播機制

pytorch 與 numpy 的陣列廣播機制

numpy 的文件提到陣列廣播機制為:
When operating on two arrays, NumPy compares their shapes element-wise. It starts with the trailing dimensions, and works its way forward. Two dimensions are compatible when
they are equal, or

  1. one of them is 1
  2. If these conditions are not met, a ValueError: frames are not aligned exception is thrown, indicating that the arrays have incompatible shapes. The size of the resulting array is the maximum size along each dimension of the input arrays.

翻譯過來就是,從兩個陣列地末尾開始算起,若軸長相等或者其中一個地維度為1,則認為是廣播相容的,否則是不相容地。廣播相容的陣列會在缺失的維度和長度為1的維度上進行。

例如:

a.shape + b.shape c.shape
(4, 1) + (1) --> (4, 1)
(4, 1) + (3,) --> (4, 3)
(2, 3, 4) + (1, 4) --> (2, 3, 4)
(2, 3, 4) + (3, 1) --> (2, 3, 4)
(2, 3, 4) + (2, 1, 1) --> (2, 3, 4)
(2, 3, 4) + (3, ) X
(4, 3) + (4,) X
(4, 3) + (3,) --> (4, 3)
(4, 3) + (3) --> (4, 3)