1. 程式人生 > >Python 進階 —— 使用修飾器執行函式的引數檢查

Python 進階 —— 使用修飾器執行函式的引數檢查

引數檢查:1. 引數的個數;2. 引數的型別;3. 返回值的型別。

考慮如下的函式:

import html

def make_tagged(text, tag):
    return '<{0}>{1}</{0}>'.format(tag, html.escape(text))

顯然我們希望傳遞進來兩個引數,且引數型別/返回值型別均為str,再考慮如下的函式:

def repeat(what, count, separator) :
    return ((what + separator)*count)[:-len(separator)]

顯然我們希望傳遞進來三個引數,分別為str

intstr型別,可對返回值不做要求。

那麼我們該如何實現對上述引數要求,進行引數檢查呢?

import functools

def statically_typed(*types, return_type=None):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if len(args) > len(types):
                raise ValueError('too many arguments'
) elif len(args) < len(types): raise ValueError('too few arguments') for i, (type_, arg) in enumerate(zip(types, args)): if not isinstance(type_, arg): raise ValueError('argument {} must be of type {}'.format(i, type_.__name__)) result = func(*args, **kwargs) if
return_type is not None and not isinstance(result, return_type): raise ValueError('return value must be of type {}'.format(return_type.__name__)) return wrapper return decorator

這樣,我們便可以使用修飾器模板執行引數檢查了:

@statically_typed(str, str, return_type=str)
def make_tagged(text, tag):
    return '<{0}>{1}</{0}>'.format(tag, html.escape(text))

@statically_typed(str, int, str)
def repeat(what, count, separator):
    return ((what + separator)*count)[:-len(separator)]

注:從靜態型別語言(C/C++、Java)轉入 Python 的開發者可能比較喜歡用修飾器對函式的引數及返回值執行靜態型別檢查,但這樣做會增加 Python 程式在執行期的開銷,而編譯型語言則沒有這種執行期開銷(Python 是解釋型語言)。