【入门教程】官方手把手教你如何将 TensorFlow 1 升级到 TensorFlow 2(上)

这篇博客详细介绍了如何将使用低级别 TensorFlow API 的代码升级到 TensorFlow 2.0,包括自动转换脚本、高层行为变更、模型转换和训练流程的调整。内容涵盖自动脚本的使用、资源变量、控制流变化以及如何让代码适应 2.0 的原生特性,旨在帮助开发者使代码更简洁、高效和易于维护。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

深度学习框架 TensorFlow 升级到 2.0 后,旧版本里的 TensorFlow 代码该如何升级到 2.0?本篇文章为您带来详细的 TensorFlow 升级教程,借助这份指南帮您将代码升级到 TensorFlow 2,使您的代码更加简洁、更容易维护。

↓点击观看代码迁移视频教程↓

如何将代码迁移至 TensorFlow 2 ?

重要说明:
这份文档适用于使用低级别 TensorFlow API 的用户。如果您正在使用高级别 API (tf.keras),可能无需或仅需对您的代码执行很少操作,便可以让代码完全兼容 TensorFlow 2.0。查看您的优化器默认学习速率。

      

在 TensorFlow 2.0 中,1.X 的代码不经修改也许还可运行(除了 contrib):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

但是,这样做无法让您利用到我们对 TensorFlow 2.0 做出的许多改进。这份指南可以帮助您升级代码,让代码更加简洁、性能更好、更容易维护。

 

自动转换脚本

尝试执行文档中描述的这些变更之前,首先需运行此升级脚本

这是将您的代码升级到 TensorFlow 2.0 的第一步。但您的代码不会因此具有 2.0 的特点。您的代码仍然可以使用 tf.compat.v1 端点来访问占位符、会话、集合以及 1.x 版本的其他功能。

 

高层行为变更

如果您使用tf.compat.v1.disable_v2_behavior()让您的代码可以在 TensorFlow 2.0 中工作,那么您仍需要处理全局行为变更。重大变更有:

  • Eager executionv1.enable_eager_execution():任何隐式使用 tf.Graph 的代码都会执行失败。确保将这个代码打包进 with tf.Graph().as_default() 上下文中。

  • 资源变量v1.enable_resource_variables():一些代码可能会依赖由 TF 参考变量启用的非确定性行为。对资源变量进行写入时,它会处于锁定状态,因此可以提供更直观的一致性保证。

    • 在边缘情形中,启用资源变量可能会改变其中的行为。

    • 启用资源变量可能会创建额外的副本以及有更高的内存使用情况。

    • 可以通过将 use_resource=False 传递给 tf.Variable 构造器来禁用资源变量。

  • Tensor Shapesv1.enable_v2_tensorshape():TF 2.0 可简化张量形状的行为。您可以使用 t.shape[0],而不需要使用 t.shape[0].value。这样的变更很小,因此最好立即加以修复。有关示例,请参阅 TensorShape。

  • Control flowv1.enable_control_flow_v2():TF 2.0 中的控制流实施已得到简化,因此会有不同的图表征。如有任何问题,请提交错误报告

 

 

让代码成为 2.0 原生代码

这份指南会逐步讲解将 TensorFlow 1.x 代码转换成 TensorFlow 2.0 代码的示例。这一变更让您的代码可以获得性能优化和简化的 API 调用的优势。

 

在以下各种情况下,模式为:

1. 替换 v1.Session.run 调用

每一个 v1.Session.run 调用都应该替换为一个 Python 函数。

  • feed_dict 和 v1.placeholder 成为函数参数;

  • fetches 成为函数的返回值;

  • 在转换期间,即刻执行使用标准 Python 工具(如 pdb)让调试变得简单。

 

之后添加一个 tf.function 修饰器,这样可以更有效率地在图中运行。如需了解其工作原理的更多内容,请参阅 AutoGraph 指南 。

 

请注意:

  • v1.Session.run不同,tf.function有固定的返回签名,并且总是返回所有的输出。如果这样会导致性能问题,建议创建两个单独的函数;

  • 不需要进行tf.control_dependencies或类似的操作:tf.function 会像事先写好一样,按顺序运行。例如,tf.Variable 赋值和 tf.assert 就会自动执行。

 

2. 使用 Python 对象来追踪变量和损失

在 TF 2.0 中,强烈建议不要使用基于名称的变量追踪。请使用 Python 对象来追踪变量。

请使用tf.Variable,而不要使用v1.get_variable

每一个v1.variable_scope 应该转换为一个 Python 对象。通常是下列对象中的一个:

  • tf.keras.layers.Layer

  • tf.keras.Model

  • tf.Module

 

如果您需要聚合变量列表(如tf.Graph.get_collection(tf.GraphKeys.VARIABLES),请使用LayerModel对象中的.variables.trainable_variables属性。


这些 Layer 和 Model 类会实施多种其他属性,因此不需要全局集合。其 .losses 属性可以替代 tf.GraphKeys.LOSSES 集合。


请参阅 Keras 指南了解详情。

警告:
许多 tf.compat.v1 符号会以隐式方式使用全局集合。

3. 升级您的训练循环

使用适合您用例的最高级别 API。推荐使用 tf.keras.Model.fit 构建自己的训练循环。

这些高级别函数负责管理许多自己编写训练循环时容易漏掉的低级别细节。例如,这些函数会自动收集正则化损失,并且当调用模型时设置 training=True 参数。

4. 升级您的数据输入流水线

使用 tf.data 数据集进行数据输入。这些对象高效、可表达,与 TensorFlow 有很好的集成。

可以直接传递到 tf.keras.Model.fit 方法。

model.fit(dataset, epochs=5)

可以直接通过标准 Python 来遍历:

for example_batch, label_batch in dataset:
    break

5. 迁移 compat.v1 符号

tf.compat.v1 模块含有完整的 TensorFlow 1.x API,且具有其原始的语义。

如果此类转换安全,则 TF2 升级脚本会将符号转换成 2.0 版本中对应的符号,即脚本能够确认行为在 2.0 版本与 1.x 中完全等效(例如,脚本判断两者是同一个函数,因此将 v1.arg_max 重命名为 tf.argmax)。

升级脚本运行结束后,会留下一段代码,其中很可能多次出现 compat.v1。建议逐行检查代码,手动将其转换成 2.0 版本中对应的符号(如果有对应的符号,则会在日志中提及)。

 

 

模型转换

1. 设置

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

import tensorflow_datasets as tfds

 

2. 低级别变量和运算符

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值