TensorFlow模型保存和提取方法

本文详细介绍TensorFlow中模型的保存与加载方法,包括使用tf.train.Saver类保存和加载模型的具体步骤,如何重命名变量,加载滑动平均模型以及通过convert_variables_to_constants函数将变量及其取值作为常量保存。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

TensorFlow模型保存和提取方法

原创 2017年06月01日 11:25:25
  • 7004

一、TensorFlow模型保存和提取方法

1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,"Model/model.ckpt"),实际在这个文件目录下会生成4个人文件:

checkpoint文件保存了一个录下多有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model.ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同,但加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。

2. 加载这个已保存的TensorFlow模型的方法是saver.restore(sess,"./Model/model.ckpt"),加载模型的代码中也要定义TensorFlow计算图上的所有运算并声明一个tf.train.Saver类,不同的是加载模型时不需要进行变量的初始化,而是将变量的取值通过保存的模型加载进来,注意加载路径的写法。若不希望重复定义计算图上的运算,可直接加载已经持久化的图,saver =tf.train.import_meta_graph("Model/model.ckpt.meta")

3.tf.train.Saver类也支持在保存和加载时给变量重命名,声明Saver类对象的时候使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名},saver = tf.train.Saver({"v1":u1, "v2": u2})即原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中。

4. 上一条做的目的之一就是方便使用变量的滑动平均值。如果在加载模型时直接将影子变量映射到变量自身,则在使用训练好的模型时就不需要再调用函数来获取变量的滑动平均值了。载入时,声明Saver类对象时通过一个字典将滑动平均值直接加载到新的变量中,saver = tf.train.Saver({"v/ExponentialMovingAverage": v}),另通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典。

此外,通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中。

二、TensorFlow程序实现

[python]  view plain  copy
 
  1. # 本文件程序为配合教材及学习进度渐进进行,请按照注释分段执行  
  2. # 执行时要注意IDE的当前工作过路径,最好每段重启控制器一次,输出结果更准确  
  3.   
  4.   
  5. # Part1: 通过tf.train.Saver类实现保存和载入神经网络模型  
  6.   
  7. # 执行本段程序时注意当前的工作路径  
  8. import tensorflow as tf  
  9.   
  10. v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")  
  11. v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")  
  12. result = v1 + v2  
  13.   
  14. saver = tf.train.Saver()  
  15.   
  16. with tf.Session() as sess:  
  17.     sess.run(tf.global_variables_initializer())  
  18.     saver.save(sess, "Model/model.ckpt")  
  19.   
  20.   
  21. # Part2: 加载TensorFlow模型的方法  
  22.   
  23. import tensorflow as tf  
  24.   
  25. v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")  
  26. v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")  
  27. result = v1 + v2  
  28.   
  29. saver = tf.train.Saver()  
  30.   
  31. with tf.Session() as sess:  
  32.     saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./"  
  33.     print(sess.run(result)) # [ 3.]  
  34.   
  35.   
  36. # Part3: 若不希望重复定义计算图上的运算,可直接加载已经持久化的图  
  37.   
  38. import tensorflow as tf  
  39.   
  40. saver = tf.train.import_meta_graph("Model/model.ckpt.meta")  
  41.   
  42. with tf.Session() as sess:  
  43.     saver.restore(sess, "./Model/model.ckpt") # 注意路径写法  
  44.     print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.]  
  45.   
  46.   
  47. # Part4: tf.train.Saver类也支持在保存和加载时给变量重命名  
  48.   
  49. import tensorflow as tf  
  50.   
  51. # 声明的变量名称name与已保存的模型中的变量名称name不一致  
  52. u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")  
  53. u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")  
  54. result = u1 + u2  
  55.   
  56. # 若直接生命Saver类对象,会报错变量找不到  
  57. # 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名}  
  58. # 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中  
  59. saver = tf.train.Saver({"v1": u1, "v2": u2})  
  60.   
  61. with tf.Session() as sess:  
  62.     saver.restore(sess, "./Model/model.ckpt")  
  63.     print(sess.run(result)) # [ 3.]  
  64.   
  65.   
  66. # Part5: 保存滑动平均模型  
  67.   
  68. import tensorflow as tf  
  69.   
  70. v = tf.Variable(0, dtype=tf.float32, name="v")  
  71. for variables in tf.global_variables():  
  72.     print(variables.name) # v:0  
  73.   
  74. ema = tf.train.ExponentialMovingAverage(0.99)  
  75. maintain_averages_op = ema.apply(tf.global_variables())  
  76. for variables in tf.global_variables():  
  77.     print(variables.name) # v:0  
  78.                           # v/ExponentialMovingAverage:0  
  79.   
  80. saver = tf.train.Saver()  
  81.   
  82. with tf.Session() as sess:  
  83.     sess.run(tf.global_variables_initializer())  
  84.     sess.run(tf.assign(v, 10))  
  85.     sess.run(maintain_averages_op)  
  86.     saver.save(sess, "Model/model_ema.ckpt")  
  87.     print(sess.run([v, ema.average(v)])) # [10.0, 0.099999905]  
  88.   
  89.   
  90. # Part6: 通过变量重命名直接读取变量的滑动平均值  
  91.   
  92. import tensorflow as tf  
  93.   
  94. v = tf.Variable(0, dtype=tf.float32, name="v")  
  95. saver = tf.train.Saver({"v/ExponentialMovingAverage": v})  
  96.   
  97. with tf.Session() as sess:  
  98.     saver.restore(sess, "./Model/model_ema.ckpt")  
  99.     print(sess.run(v)) # 0.0999999  
  100.   
  101.   
  102. # Part7: 通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典  
  103.   
  104. import tensorflow as tf  
  105.   
  106. v = tf.Variable(0, dtype=tf.float32, name="v")  
  107. # 注意此处的变量名称name一定要与已保存的变量名称一致  
  108. ema = tf.train.ExponentialMovingAverage(0.99)  
  109. print(ema.variables_to_restore())  
  110. # {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}  
  111. # 此处的v取自上面变量v的名称name="v"  
  112.   
  113. saver = tf.train.Saver(ema.variables_to_restore())  
  114.   
  115. with tf.Session() as sess:  
  116.     saver.restore(sess, "./Model/model_ema.ckpt")  
  117.     print(sess.run(v)) # 0.0999999  
  118.   
  119.   
  120. # Part8: 通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中  
  121.   
  122. import tensorflow as tf  
  123. from tensorflow.python.framework import graph_util  
  124.   
  125. v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")  
  126. v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")  
  127. result = v1 + v2  
  128.   
  129. with tf.Session() as sess:  
  130.     sess.run(tf.global_variables_initializer())  
  131.     # 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分  
  132.     graph_def = tf.get_default_graph().as_graph_def()  
  133.     output_graph_def = graph_util.convert_variables_to_constants(sess,  
  134.                                                         graph_def, ['add'])  
  135.   
  136.     with tf.gfile.GFile("Model/combined_model.pb", 'wb') as f:  
  137.         f.write(output_graph_def.SerializeToString())  
  138.   
  139.   
  140. # Part9: 载入包含变量及其取值的模型  
  141.   
  142. import tensorflow as tf  
  143. from tensorflow.python.platform import gfile  
  144.   
  145. with tf.Session() as sess:  
  146.     model_filename = "Model/combined_model.pb"  
  147.     with gfile.FastGFile(model_filename, 'rb') as f:  
  148.         graph_def = tf.GraphDef()  
  149.         graph_def.ParseFromString(f.read())  
  150.   
  151.     result = tf.import_graph_def(graph_def, return_elements=["add:0"])  
  152.     print(sess.run(result)) # [array([ 3.], dtype=float32)]  

转载于:https://www.cnblogs.com/developer-ios/p/7855444.html

<think>首先,用户的问题是关于TensorFlow模型保存方法代码示例的。根据系统指令,我需要: 1. 所有行内数学表达式使用$...$格式。 2. 独立公式使用$$...$$格式并单独成段。 3. 确保LaTeX语法正确。 4. 使用中文回答。 5. 生成相关问题。 6. 在回答中引用的段落末尾自然地添加引用标识。 7. 回答结构清晰,帮助用户逐步解决问题,保证回答真实可靠。 8. 参考站内引用,但用户提供的引用是关于TensorFlow模型保存的,所以我应该基于这些信息来构建回答。 用户提供的引用内容: - [^1]: 讨论了PyTorchTensorFlow模型保存格式。 - [^2]: 详细介绍了TensorFlow pb模型的常用处理方法。 - [^3]: 关于PHP调用TensorFlow模型保存提取方法。 - [^4]: 展示了saver.save方法的输出。 用户的问题: "请问TensorFlow 模型保存方法及代码示例" 所以,我需要聚焦于TensorFlow模型保存方法,并提供代码示例。 关键点: - TensorFlow有多种保存格式:SavedModel, HDF5 (.h5), .pb, .tflite等。 - 常见方法:使用tf.train.Saver保存为检查点文件(.ckpt),或使用tf.saved_model保存为SavedModel格式。 基于引用[2][3]: - 保存模型可以使用tf.train.Saver保存为.ckpt文件。 - 或者保存为.pb文件(Protocol Buffers格式)。 - 还有SavedModel格式,这是TensorFlow推荐的格式。 代码示例: 1. 使用tf.train.Saver保存检查点文件: ```python import tensorflow as tf # 定义模型 model = tf.keras.Sequential([...]) model.compile(...) model.fit(...) # 创建Saver对象 saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 训练模型... saver.save(sess, "model.ckpt") # 保存检查点 ``` 这会生成文件:model.ckpt.data-00000-of-00001, model.ckpt.index, model.ckpt.meta。 2. 保存为SavedModel格式: ```python model.save("saved_model") # 对于tf.keras模型 # 或者使用tf.saved_model.save ``` 3. 保存为HDF5格式(.h5): ```python model.save("model.h5") # 保存整个模型为HDF5文件 ``` 4. 保存为.pb文件(冻结图): 需要先冻结模型,使用tf.graph_util.convert_variables_to_constants。 示例代码: ```python from tensorflow.python.framework import graph_util # 假设sess是会话 output_names = [n.name for n in tf.get_default_graph().as_graph_def().node] constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_names) with tf.gfile.GFile("frozen_model.pb", "wb") as f: f.write(constant_graph.SerializeToString()) ``` 回答结构: - 清晰结构:逐步介绍不同方法。 - 使用中文。 - 在适当的地方添加引用标识。例如,提到保存方法时,引用[^2]或[^3]。 - 最后生成相关问题。 生成相关问题: - 基于主题:TensorFlow模型保存、加载、优化、转换等。 例如: 1. 如何加载保存TensorFlow模型? 2. 不同保存格式的区别是什么? 3. 如何优化TensorFlow模型大小? 确保LaTeX:如果涉及数学,如损失函数,用$...$。但在这个上下文中,可能不需要太多数学,除非讨论模型结构。 现在,构建回答大纲: 1. 介绍TensorFlow模型保存的重要性。 2. 介绍主要保存方法: a. 检查点文件(.ckpt)使用tf.train.Saver。 b. SavedModel格式(推荐)。 c. HDF5格式(.h5)。 d. 冻结图(.pb)。 3. 为每种方法提供代码示例。 4. 比较格式。 5. 添加引用。 6. 结束生成相关问题。</think>### TensorFlow 模型保存方法及代码示例 在 TensorFlow 中,保存训练好的模型是深度学习工作流的关键步骤,它确保模型可以复用、部署或共享。TensorFlow 支持多种保存格式,包括检查点文件(.ckpt)、SavedModel、HDF5 (.h5) 冻结图 (.pb)。每种格式适用于不同场景,例如 SavedModel 是 TensorFlow 推荐的格式,支持跨平台部署;而 .ckpt 文件适合训练中断恢复[^2]。下面我将逐步介绍主要保存方法,并提供代码示例(基于 TensorFlow 2.x)。 #### 1. **保存为检查点文件(.ckpt)** 检查点文件包含模型的权重、计算图元数据,适合训练过程中保存中间状态,便于恢复训练。使用 `tf.train.Saver` 类实现保存(在 TensorFlow 1.x 中常见,但 2.x 也兼容)。 **代码示例:** ```python import tensorflow as tf # 定义简单模型(示例为线性回归) x = tf.placeholder(tf.float32, shape=[None, 1], name='input') y_true = tf.placeholder(tf.float32, shape=[None, 1], name='label') W = tf.Variable(tf.zeros([1, 1]), name='weight') b = tf.Variable(tf.zeros([1]), name='bias') y_pred = tf.matmul(x, W) + b loss = tf.reduce_mean(tf.square(y_true - y_pred)) optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss) # 创建 Saver 对象 saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 训练模型(示例数据) for _ in range(100): sess.run(optimizer, feed_dict={x: [[1], [2]], y_true: [[2], [4]]}) # 保存模型到检查点文件 saver.save(sess, "model.ckpt") # 生成文件: model.ckpt.data-00000-of-00001, model.ckpt.index, model.ckpt.meta[^4] ``` - **输出文件**: `model.ckpt.data-00000-of-00001`(权重)、`model.ckpt.index`(索引)、`model.ckpt.meta`(计算图)。 - **特点**: 支持增量保存,但需手动管理会话。适合训练中断恢复场景[^3]。 #### 2. **保存为 SavedModel 格式** SavedModel 是 TensorFlow 的标准格式,包含完整模型(计算图、权重签名),易于部署到 TensorFlow Serving 或移动端。使用 `tf.saved_model.save` 或 Keras API 保存。 **代码示例(使用 Keras API):** ```python import tensorflow as tf # 定义并训练 Keras 模型(示例为简单分类模型) model = tf.keras.Sequential([ tf.keras.layers.Dense(10, activation='relu', input_shape=(784,)), tf.keras.layers.Dense(10, activation='softmax') ]) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.fit(x_train, y_train, epochs=5) # 假设 x_train, y_train 是训练数据 # 保存为 SavedModel 格式 model.save("saved_model") # 生成目录: saved_model/assets, saved_model/variables, saved_model/saved_model.pb ``` - **输出文件**: 目录包含 `saved_model.pb`(计算图)、`variables/`(权重) `assets/`(辅助文件)。 - **特点**: 支持签名(如输入输出定义),便于推理。是 TensorFlow 2.x 推荐格式[^1]。 #### 3. **保存为 HDF5 格式(.h5)** HDF5 格式将整个模型(架构、权重优化器状态)保存为单个文件,适合 Keras 模型的快速保存加载。 **代码示例:** ```python import tensorflow as tf # 定义并训练 Keras 模型 model = tf.keras.Sequential([ tf.keras.layers.Dense(64, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) model.fit(x_train, y_train, epochs=5) # 保存为 HDF5 文件 model.save("model.h5") # 生成文件: model.h5 ``` - **输出文件**: 单个 `.h5` 文件包含所有模型信息。 - **特点**: 文件紧凑,加载方便,但仅适用于 Keras 模型,不支持 TensorFlow Serving[^1]。 #### 4. **保存为冻结图(.pb)** 冻结图将模型权重嵌入计算图,生成单个 `.pb` 文件,适合部署到移动端或嵌入式设备。需先“冻结”变量为常量。 **代码示例:** ```python import tensorflow as tf from tensorflow.python.framework import graph_util # 定义模型(同检查点示例) x = tf.placeholder(tf.float32, shape=[None, 1], name='input') y_pred = tf.matmul(x, W) + b # 假设 W, b 已定义 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 训练模型... # 冻结图:将变量转换为常量 output_names = [y_pred.op.name] # 输出节点名 constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def, output_names) # 保存为 .pb 文件 with tf.gfile.GFile("frozen_model.pb", "wb") as f: f.write(constant_graph.SerializeToString()) ``` - **输出文件**: `frozen_model.pb`(单个文件)。 - **特点**: 文件小、加载快,但无法修改权重。常用于 TensorFlow Lite 转换[^2]。 #### 格式比较与最佳实践 | 格式 | 文件类型 | 适用场景 | 优点 | 缺点 | |---------------|------------------|-----------------------------------|-------------------------------|-------------------------------| | 检查点 (.ckpt) | 多个文件 | 训练中断恢复 | 支持增量保存 | 需手动管理会话[^3] | | SavedModel | 目录 | 部署(TensorFlow Serving) | 完整模型、跨平台兼容 | 文件较大 | | HDF5 (.h5) | 单个文件 | Keras 模型快速保存 | 简单易用 | 不支持签名定义[^1] | | 冻结图 (.pb) | 单个文件 | 移动端部署、TensorFlow Lite | 文件小、推理快 | 无法修改权重[^2] | **最佳实践**: - 训练中: 使用 `.ckpt` 定期保存检查点。 - 部署: 优先使用 SavedModel 格式。 - 移动端: 转换为 `.pb` 或 TensorFlow Lite(.tflite)。 - 避免常见坑点: 确保输入输出节点名正确,加载时检查 TensorFlow 版本兼容性[^1][^2]。 #### 模型加载示例(简要) 加载模型保存对应: - 加载 `.ckpt`: `saver.restore(sess, "model.ckpt")` - 加载 SavedModel: `tf.saved_model.load("saved_model")` - 加载 `.h5`: `tf.keras.models.load_model("model.h5")` - 加载 `.pb`: 使用 `tf.graph_util.import_graph_def` 如需更详细的加载代码,请参考 TensorFlow 官方文档或相关教程[^1][^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值