Tensorflow-@tf_export详解

本文深入探讨了TensorFlow中@tf_export装饰器的工作原理,解释了如何为函数指定名称,涉及装饰器、偏函数及类的调用机制。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

@tf_export为函数取了个名字!


Tensorflow经常看到定义的函数前面加了“@tf_export”。例如,tensorflow/python/platform/app.py中有:

@tf_export('app.run')
def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""

  # Define help flags.
  _define_help_flags()

  # Parse known flags.
  argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True)

  main = main or _sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  _sys.exit(main(argv))

 

首先,@tf_export是一个修饰符。修饰符的本质是一个函数,不懂可以撮戳这里

tf_export的实现在tensorflow/python/util/tf_export.py中:

tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)

  等号的右边的理解分两步:

  1.functools.partial

  2.api_export

   functools.partial是偏函数,它的本质简而言之是为函数固定某些参数。如:functools.partial(FuncA, p1)的作用是把函数FuncA的第一个参数固定为p1;又如functools.partial(FuncB, key1="Hello")的作用是把FuncB中的参数key1固定为“Hello"。  

    functools.partial(api_export, api_name=TENSORFLOW_API_NAME)的意思是把api_export的api_name这个参数固定为TENSORFLOW_API。其中TENSORFLOW_API_NAME = 'tensorflow'。

  api_export是实现了__call__()函数的类,不懂戳这里,简而言之是把类变得可以像函数一样调用。

  tf_export=unctools.partial(api_export, api_name=TENSORFLOW_API_NAME)的写法等效于:  

funcC = api_export(api_name=TENSORFLOW_API_NAME)
tf_export = funcC

对于funcC = api_export(api_name=TENSORFLOW_API_NAME),会导致__init__(api_name=TENSORFLOW_API_NAME)被调用:

  def __init__(self, *args, **kwargs):
    self._names = args
    self._names_v1 = kwargs.get('v1', args)
    self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
    self._overrides = kwargs.get('overrides', [])
    self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)

  其中第4行self._api_name=kwargs.get('api_name', TENSORFLOW_API_NAME)的意思是获取api_name这个参数,如果未检测到该参数,则默认为TENSORFLOW_API_NAME。由此看,api_name这个参数传进来和默认的值都是TENSORFLOW_API_NAME,最终的结果是self._api_name=TENSORFLOW_API_NAME。

  然后调用像函数一样调用funcC()实际上就会调用__call__():

  def __call__(self, func):
    api_names_attr = API_ATTRS[self._api_name].names       -----1
    api_names_attr_v1 = API_ATTRS_V1[self._api_name].names
    # Undecorate overridden names
    for f in self._overrides:
      _, undecorated_f = tf_decorator.unwrap(f)
      delattr(undecorated_f, api_names_attr)
      delattr(undecorated_f, api_names_attr_v1)

    _, undecorated_func = tf_decorator.unwrap(func)       -----2
    self.set_attr(undecorated_func, api_names_attr, self._names)  ----3
    self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)
    return func

因此@tf_export("app.run")最终的结果是用上面这个__call__()来作为修饰器。这是一个带参数的修饰器(真心有点复杂)!

 标注1:

  api_names_attr = API_ATTRS[self._api_name].names: 中的self._api_name即为__init__()中提到的TENSORFLOW_API_NAME。看看API_ATTRS中都有些什么:


_Attributes = collections.namedtuple(
    'ExportedApiAttributes', ['names', 'constants'])

# Attribute values must be unique to each API.
API_ATTRS = {
    TENSORFLOW_API_NAME: _Attributes(
        '_tf_api_names',
        '_tf_api_constants'),
    ESTIMATOR_API_NAME: _Attributes(
        '_estimator_api_names',
        '_estimator_api_constants')
}

collections.namedtuple()返回具有命名字段的元组的新子类。从“ExportedApiAttributes”可推测这是用来管理已输出的API的属性的。

  标注2:

    _, undecorated_func = tf_decorator.unwrap(func)

def unwrap(maybe_tf_decorator):
  """Unwraps an object into a list of TFDecorators and a final target.

  Args:
    maybe_tf_decorator: Any callable object.

  Returns:
    A tuple whose first element is an list of TFDecorator-derived objects that
    were applied to the final callable target, and whose second element is the
    final undecorated callable target. If the `maybe_tf_decorator` parameter is
    not decorated by any TFDecorators, the first tuple element will be an empty
    list. The `TFDecorator` list is ordered from outermost to innermost
    decorators.
  """
  decorators = []
  cur = maybe_tf_decorator
  while True:
    if isinstance(cur, TFDecorator):
      decorators.append(cur)
    elif hasattr(cur, '_tf_decorator'):
      decorators.append(getattr(cur, '_tf_decorator'))
    else:
      break
    cur = decorators[-1].decorated_target
  return decorators, cur

将对象展开到tfdecorator列表和最终目标列表中。undecorated_func获得的返回对象就是我们有@tf_export修饰的函数。

标注3:self.set_attr(undecorated_func, api_names_attr, self._names) 设置属性。

 

总结:@tf_export修饰器为所修饰的函数取了个名字!

 

 

参考资源链接:[ANACONDA+Cuda/cuDNN+Tensorflow-gpu与Keras安装详解PPT](https://wenku.youkuaiyun.com/doc/6y7oby6j13?utm_source=wenku_answer2doc_content) 要在Anaconda中安装并配置TensorFlow-gpu和Keras以支持GPU,首先确保你的系统已经安装了兼容的NVIDIA驱动程序。然后,参考以下步骤进行操作: 1. 安装Anaconda:访问Anaconda官网下载并安装适合你操作系统的Anaconda版本。Anaconda能够帮助你创建独立的Python环境,便于管理不同项目依赖。 2. 安装CUDA:根据你的NVIDIA GPU型号和操作系统,从NVIDIA官网下载对应版本的CUDA Toolkit,并进行安装。安装完成后,通过命令行工具(如`nvcc --version`)检查CUDA是否安装成功。 3. 安装cuDNN:从NVIDIA官方cuDNN下载页面获取cuDNN库,解压后将包含的库文件复制到CUDA安装目录下的相应文件夹内。同时,需要设置环境变量(如`export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/cudnn/lib`),确保系统能够找到cuDNN库。 4. 创建并激活TensorFlow-gpu环境:打开Anaconda Prompt,运行`conda create -n tf_gpu python=3.7`创建一个新环境,然后使用`conda activate tf_gpu`激活它。 5. 安装TensorFlow-gpu:在激活的环境中,使用命令`pip install tensorflow-gpu`安装TensorFlow-gpu。这个命令会自动安装与你CUDA版本兼容的TensorFlow-gpu版本。 6. 安装Keras:Keras作为TensorFlow的高级API,可以通过简单地运行`pip install keras`来安装。确保在安装Keras之前已经正确设置了CUDA和cuDNN。 在安装过程中,可能会遇到与CUDA版本不兼容的问题,建议查阅TensorFlow官方文档中的GPU支持矩阵,选择与你的CUDA版本相兼容的TensorFlow-gpu版本进行安装。 完成以上步骤后,你可以通过在Python中运行`import tensorflow as tf`和`import keras`来验证安装是否成功。如果系统能够正确识别GPU,则在`tf.Session()`时会看到GPU被使用的信息。 建议在安装后查看《ANACONDA+Cuda/cuDNN+Tensorflow-gpu与Keras安装详解PPT》,这份资源不仅提供了详细的安装步骤,还通过实例展示了如何利用这些工具进行深度学习开发,是深入理解和实践的宝贵资料。 参考资源链接:[ANACONDA+Cuda/cuDNN+Tensorflow-gpu与Keras安装详解PPT](https://wenku.youkuaiyun.com/doc/6y7oby6j13?utm_source=wenku_answer2doc_content)
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值