[email protected]_export詳解
Tensorflow經常看到定義的函式前面加了“@tf_export”。例如,tensorflow/python/platform/app.py中有:
@tf_export('app.run') def run(main=None, argv=None): """Runs the program with an optional 'main' function and 'argv' list.""" # Define help flags. _define_help_flags() # Parse known flags. argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True) main = main or _sys.modules['__main__'].main # Call the main function, passing through any arguments # to the final program. _sys.exit(main(argv))
首先,@tf_export是一個修飾符。修飾符的本質是一個函式,不懂可以撮戳這裡。
tf_export的實現在tensorflow/python/util/tf_export.py中:
tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
等號的右邊的理解分兩步:
1.functools.partial
2.ap_export
functools.partial是偏函式,它的本質簡而言之是為函式固定某些引數。如:functools.partial(FuncA, p1)的作用是把函式FuncA的第一個引數固定為p1;又如functools.partial(FuncB, key1="Hello")的作用是把FuncB中的引數key1固定為“Hello"。
functools.partial(api_export, api_name=TENSORFLOW_API_NAME)的意思是把api_export的api_name這個引數固定為TENSORFLOW_API。其中TENSORFLOW_API_NAME = 'tensorflow'。
api_export是實現了__call__()函式的類,不懂戳這裡,簡而言之是把類變得可以像函式一樣呼叫。
tf_export=unctools.partial(api_export, api_name=TENSORFLOW_API_NAME)的寫法等效於:
funcC = api_export(api_name=TENSORFLOW_API_NAME)
tf_export = funcC
對於funcC = api_export(api_name=TENSORFLOW_API_NAME),會導致__init__(api_name=TENSORFLOW_API_NAME)被呼叫:
def __init__(self, *args, **kwargs):
self._names = args
self._names_v1 = kwargs.get('v1', args)
self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
self._overrides = kwargs.get('overrides', [])
self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)
其中第4行self._api_name=kwargs.get('api_name', TENSORFLOW_API_NAME)的意思是獲取api_name這個引數,如果未檢測到該引數,則預設為TENSORFLOW_API_NAME。由此看,api_name這個引數傳進來和預設的值都是TENSORFLOW_API_NAME,最終的結果是self._api_name=TENSORFLOW_API_NAME。
然後呼叫像函式一個呼叫funcC()實際上就會呼叫__call__():
def __call__(self, func):
api_names_attr = API_ATTRS[self._api_name].names -----1
api_names_attr_v1 = API_ATTRS_V1[self._api_name].names
# Undecorate overridden names
for f in self._overrides:
_, undecorated_f = tf_decorator.unwrap(f)
delattr(undecorated_f, api_names_attr)
delattr(undecorated_f, api_names_attr_v1)
_, undecorated_func = tf_decorator.unwrap(func) -----2
self.set_attr(undecorated_func, api_names_attr, self._names) ----3
self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)
return func
因此@tf_export("app.run")最終的結果是用上面這個__call__()來作為修飾器。這是一個帶引數的修飾器(真心有點複雜)!
標註1:
api_names_attr = API_ATTRS[self._api_name].names: 中的self._api_name即為__init__()中提到的TENSORFLOW_API_NAME。看看API_ATTRS中都有些什麼:
_Attributes = collections.namedtuple(
'ExportedApiAttributes', ['names', 'constants'])
# Attribute values must be unique to each API.
API_ATTRS = {
TENSORFLOW_API_NAME: _Attributes(
'_tf_api_names',
'_tf_api_constants'),
ESTIMATOR_API_NAME: _Attributes(
'_estimator_api_names',
'_estimator_api_constants')
}
collections.namedtuple()返回具有命名欄位的元組的新子類。從“ExportedApiAttributes”可推測這是用來管理已輸出的API的屬性的。
標註2:
_, undecorated_func = tf_decorator.unwrap(func)
def unwrap(maybe_tf_decorator):
"""Unwraps an object into a list of TFDecorators and a final target.
Args:
maybe_tf_decorator: Any callable object.
Returns:
A tuple whose first element is an list of TFDecorator-derived objects that
were applied to the final callable target, and whose second element is the
final undecorated callable target. If the `maybe_tf_decorator` parameter is
not decorated by any TFDecorators, the first tuple element will be an empty
list. The `TFDecorator` list is ordered from outermost to innermost
decorators.
"""
decorators = []
cur = maybe_tf_decorator
while True:
if isinstance(cur, TFDecorator):
decorators.append(cur)
elif hasattr(cur, '_tf_decorator'):
decorators.append(getattr(cur, '_tf_decorator'))
else:
break
cur = decorators[-1].decorated_target
return decorators, cur
將物件展開到tfdecorator列表和最終目標列表中。undecorated_func獲得的返回物件就是我們有@tf_export修飾的函式。
標註3:self.set_attr(undecorated_func, api_names_attr, self._names) 設定屬性。
總結:@tf_export修飾器為所修改的函式取了個名字!