1. 程式人生 > >numpy中的squeeze()函式

numpy中的squeeze()函式

numpy.squeeze(a, axis=None)

squeeze()函式的功能是:從矩陣shape中,去掉維度為1的。例如一個矩陣是的shape是(5, 1),使用過這個函式後,結果為(5,)。

引數
a是輸入的矩陣
axis : 選擇shape中的一維條目的子集。如果在shape大於1的情況下設定axis,則會引發錯誤。

栗子
要使用numpy先匯入numpy庫
import numpy as np


>>> x = np.array([[[0], [1], [2]]])
>>> x.shape
(1, 3, 1)
>>> np.squeeze(x).shape
(3
,) >>> np.squeeze(x, axis=(2,)).shape (1, 3)

squeeze()的原始碼

def squeeze(a, axis=None):
    """
    Remove single-dimensional entries from the shape of an array.
    Parameters
    ----------
    a : array_like
        Input data.
    axis : None or int or tuple of ints, optional
        .. versionadded:: 1.7.0
        Selects a subset of the single-dimensional entries in the
        shape. If an axis is selected with shape entry greater than
        one, an error is raised.
    Returns
    -------
    squeezed : ndarray
        The input array, but with all or a subset of the
        dimensions of length 1 removed. This is always `a` itself
        or a view into `a`.
    Examples
    --------
    >>> x = np.array([[[0], [1], [2]]])
    >>> x.shape
    (1, 3, 1)
    >>> np.squeeze(x).shape
    (3,)
    >>> np.squeeze(x, axis=(2,)).shape
    (1, 3)
    """
try: squeeze = a.squeeze except AttributeError: return _wrapit(a, 'squeeze') try: # First try to use the new axis= parameter return squeeze(axis=axis) except TypeError: # For backwards compatibility return squeeze()