tensorflow tf.train.Saver()在网络模型参数保存以及提取时的用法

本文详细介绍了如何在TensorFlow中使用tf.train.Saver()类保存和加载神经网络的权重和偏置参数,包括实例化Saver类、保存参数至文件、以及从文件中恢复参数的具体操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

在辛辛苦苦跑了几个小时甚至几天之后,你训练出了几十万个或者更多的参数,那么你肯定不想只使用这些参数仅仅一次,那么就涉及到这些参数的保存以及提取,幸运的是,tensorflow已经帮我们集成好了相关函数,就是接下来要介绍的tf.train.Saver() 类。

tf.train.Saver()

一 . 用于保存权重和偏重(参数)
在使用之前要先实例化一个类,例如以下代码:

saver = tf.train.Saver()

如何保存?

import tensorflow as tf
import numpy as np

# Save to file
#remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')

# init= tf.initialize_all_variables() # tf 马上就要废弃这种写法
# 替换成下面的写法:
init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "my_net/save_net.ckpt")
    #
    print("Save to path: ", save_path)

这里Saver()类有一个save方法,其参数为(会话名称,要保存的文件路径以及具体文件)保存后的结果如下:
在这里插入图片描述
这里会自动生成一个“checkpoint”文件以及其他几个.ckpt文件,用来存储参数。

保存了以后如何提取,或者说读取参数?
见以下代码:

w = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

# 这里不需要初始化步骤 init= tf.initialize_all_variables()

saver = tf.train.Saver()
with tf.Session() as sess:
    # 提取变量
    saver.restore(sess, "my_net/save_net.ckpt")
    print("weights:", sess.run(w))
    print("biases:", sess.run(b))

这里Saver() 提供了一个restore方法,其参数为(会话,需要提取的文件)

这里有几点需要说明一下:

  1. Saver() 类只能存储和提取神经网络的参数,现在还不能存储整个网络架构,这个比较操蛋(不过我相信以后肯定会出现类似的存储整个训练好的架构的函数),现如今想要使用已经训练好的参数,还是需要重新定义一个一模一样的参数变量,无论是在数据类型上,还是shape上
<think>好的,我现在需要帮助用户将现有的TensorFlow代码从导出TFLite格式修改为导出ONNX格式。让我仔细看看用户提供的代码和需求。 首先,用户给出的代码是使用TensorFlow 1.x的兼容模式,通过创建一个计算图,加载预训练的检查点文件,然后转换成TFLite模型。现在用户希望将其改为导出ONNX格式。我需要理解这段代码的结构,然后找到转换到ONNX的正确方法。 在之前的对话中,用户询问过如何将TensorFlow权重加载到模型并转成ONNX,我给出了使用tf2onnx的方法,特别是针对SavedModel格式的转换。但用户现在的代码是基于TensorFlow 1.x的图与会话,可能没有使用Keras模型,所以可能需要不同的处理方式。 首先,我需要确定用户代码的结构。代码中使用的是tf.Graph和tf.Session,加载的是检查点文件(.ckpt),然后通过TFLiteConverter转换。而转换为ONNX的话,可能需要将图冻结,或者使用tf2onnx工具来处理。 接下来,我需要回忆tf2onnx的文档,看看它是否支持从检查点文件或冻结的图来转换。通常,tf2onnx更倾向于处理SavedModel格式或Keras模型,但也可以处理冻结的.pb文件。因此,用户可能需要先将当前的图冻结,保存.pb文件,再使用tf2onnx进行转换。 用户当前的代码中,首先创建了一个图,定义了一个占位符input_image,然后调用MicroISP模型得到输出output。接着初始化变量,恢复检查点,然后使用TFLiteConverter转换。要改成ONNX,可能需要保存这个图的结构和权重,然后进行转换。 步骤可能包括: 1. 将当前会话中的图保存为SavedModel格式,或者冻结成.pb文件。 2. 使用tf2onnx工具将保存的模型转换为ONNX格式。 然而,用户当前的代码是使用TensorFlow 1.x的API,而SavedModel在TF1中可能需要不同的处理方式。或者,是否可以直接利用现有的会话和图,导出为ONNX? 另外,用户可能没有使用Keras,所以之前的SavedModel方法可能不适用。因此,可能需要将图冻结,然后使用tf2onnx的API来转换冻结的图。 另外,需要检查tf2onnx是否支持从检查点直接转换。通常,转换需要模型的结构和权重都保存为一个文件,比如冻结的图或者SavedModel。因此,用户可能需要先将模型保存为冻结的.pb文件,然后再进行转换。 因此,修改步骤可能如下: 1. 在现有代码中,恢复检查点后,将图冻结,保存.pb文件。 2. 使用tf2onnx的命令行工具或API将冻结的.pb文件转换为ONNX。 但如何冻结图呢?在TensorFlow 1.x中,通常使用freeze_graph工具,将检查点和图定义合并成冻结的图。或者,在代码中,通过将变量转换为常量。 用户当前的代码中,在恢复检查点之后,可以调用graph_util.convert_variables_to_constants来冻结图。然后保存冻结后的图。 例如: from tensorflow.python.framework import graph_util output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, [output.op.name]) 然后保存.pb文件。 之后,使用tf2onnx转换冻结的.pb文件。 所以,修改后的代码可能需要添加冻结图的步骤,并将冻结后的模型保存,然后调用tf2onnx.convert进行处理。 但用户希望将现有的代码直接修改成导出ONNX,而不是分步保存再转换。可能的话,能否在代码中直接完成,而不需要中间文件? 或者,是否可以在会话中导出为SavedModel格式,然后转换? 另外,用户当前的代码是使用tf.compat.v1的API,可能是在TensorFlow 2.x环境中运行1.x的代码。需要确保tf2onnx支持这种模式。 可能的解决方案是: 在现有的会话中,将模型导出为SavedModel格式,然后使用tf2onnx转换SavedModel到ONNX。 例如,在恢复检查点后,使用tf.saved_model.builder.SavedModelBuilder来保存模型。这需要指定输入和输出的张量,以及签名。 但这个过程可能比较复杂,需要构建保存的模型,并定义签名。对于用户当前的代码,可能更简单的方式是冻结图,保存.pb,然后用tf2onnx转换。 因此,步骤可能是: 1. 在现有代码中,恢复检查点后,冻结图。 2. 将冻结的图保存.pb文件。 3. 使用tf2onnx命令行工具或者API将.pb文件转换为ONNX。 或者,在代码中直接调用tf2onnx的API进行转换,而不需要中间文件。但需要查看tf2onnx是否支持从当前会话直接转换。 根据tf2onnx的文档,可以使用from_session的方法,类似TFLiteConverter的做法。例如,在tf2onnx的旧版本中,可能有这样的方法。不过需要确认。 在tf2onnx的API文档中,是否有类似于from_session的函数? 根据之前的经验,tf2onnx的convert方法可能需要输入的是模型的定义,比如Keras模型,或者SavedModel路径,或者冻结的图。但针对用户当前的基于Session的情况,可能需要使用不同的方法。 例如,tf2onnx可能支持通过指定输入和输出的节点名称,以及检查点路径来进行转换。或者,在用户的情况下,已经加载了Session,可以通过这个会话来提取图,然后进行转换。 可能需要使用tf2onnx.convert.from_graph_def的方法,传入graph_def,输入和输出的节点名称,以及可能的需要输入的形状。 例如,用户当前的图已经加载了权重,可以通过sess.graph获取graph_def,然后冻结后,转换为ONNX。 因此,代码修改的步骤可能如下: 在现有代码中,恢复检查点之后,冻结图,然后使用tf2onnx的API转换冻结后的graph_def。 修改后的代码可能如下: 添加冻结图的步骤: from tensorflow.python.framework import graph_util # 冻结图,将变量转换为常量 frozen_graph_def = graph_util.convert_variables_to_constants( sess, sess.graph_def, output_node_names=[output.op.name.split(':')[0]] # 假设output是张量,取操作名 ) 然后保存冻结的图: with open('frozen_model.pb', 'wb') as f: f.write(frozen_graph_def.SerializeToString()) 然后使用tf2onnx转换: python -m tf2onnx.convert --input frozen_model.pb --inputs input:0 --outputs output:0 --output model.onnx 但用户希望代码中直接完成,而不是分开步骤。所以是否可以在代码中直接调用tf2onnx的转换函数,而不需要保存中间文件? 可能需要导入tf2onnx的convert模块,并使用from_graph_def方法。 例如: import tf2onnx onnx_model, _ = tf2onnx.convert.from_graph_def( frozen_graph_def, input_names=['input:0'], # 根据placeholder的名称,可能为input:0 output_names=['output:0'], opset=13 ) with open("MicroISP.onnx", "wb") as f: f.write(onnx_model.SerializeToString()) 因此,综合这些步骤,修改后的代码应该包括冻结图和使用tf2onnx的from_graph_def函数。 现在,回到用户的原始代码,他们需要将TFLite转换的部分替换为导出ONNX的部分。原始代码中使用的是TFLiteConverter.from_session,所以类似地,需要从当前会话中获取图的信息,并进行转换。 因此,修改后的代码可能如下: 在convert_model函数中,恢复检查点后,冻结图,然后转换为ONNX: from tensorflow.python.framework import graph_util import tf2onnx ... saver.restore(sess, "pretrained_weights/MicroISP.ckpt") # 冻结图 frozen_graph_def = graph_util.convert_variables_to_constants( sess, sess.graph_def, [output.op.name.split(':')[0]] # 输出节点名称 ) # 转换到ONNX onnx_model, _ = tf2onnx.convert.from_graph_def( frozen_graph_def, input_names=[input_image.op.name], # input_image的op.name可能是"input"或者"input:0" output_names=[output.op.name], opset=13 ) # 保存ONNX模型 with open("onnx_model/MicroISP.onnx", "wb") as f: onnx_model.SerializeToString() 这里需要注意input和output的节点名称是否正确。例如,input_image的op.name可能是"input",因为用户在定义用了name="input"。而输出的名称则取决于MicroISP模型返回的张量的名称。 可能需要验证节点名称是否正确。例如,在定义input_image,使用了name="input",所以其节点名称应该是"input",而输出的output张量的名称可能是MicroISP模型中输出的名称,或者可能被命名为"output"。 另外,在冻结图,output_node_names应该是output张量的操作名,不带输出索引。例如,如果output是名为"output"的操作的第0个输出,则output.op.name是"output",而output.name是"output:0"。在冻结图,需要传入的是["output"],而不是["output:0"]。 因此,在代码中,当调用convert_variables_to_constants,output_node_names参数应该传入的是[output.op.name],即操作的名称,而不是张量的名称。 这可能是一个需要注意的地方,否则在转换可能找不到输出节点。 另外,在调用tf2onnx.convert.from_graph_def,input_names和output_names应该使用张量的名称(如"input:0"),或者操作的名称(如"input")?这需要根据实际情况来确定。 例如,在原始代码中,input_image被定义为name="input",所以张量的名称是input:0,而操作的名称是input。因此,在输入到tf2onnx,可能需要指定input_names为["input:0"],而输出可能是["output:0"]。 或者,可能需要使用操作名,这取决于tf2onnx的处理方式。这个候可能需要做一些调试。 或者,用户可以在定义placeholder,指定name,例如name="input",此张量的全名是input:0,而操作名是input。 在tf2onnx中,输入和输出需要指定为张量的名称,即带有输出索引的名称,例如input:0。 因此,在调用from_graph_def,input_names应该是一个列表,包含输入张量的名称,即["input:0"],而output_names则是输出张量的名称,例如["output:0"]。 因此,在代码中,可能需要: input_names = [input_image.op.name + ":0"] output_names = [output.op.name + ":0"] 或者,直接使用input_image.name和output.name。 例如,在定义input_image,name是"input",所以input_image.name应该是"input:0"。 因此,正确的代码可能如下: input_names = [input_image.name] output_names = [output.name] 然后传递给from_graph_def的input_names和output_names参数。 所以,在代码中: onnx_model, _ = tf2onnx.convert.from_graph_def( frozen_graph_def, input_names=[input_image.name], # 即"input:0" output_names=[output.name], # 例如"output:0" opset=13 ) 这应该能正确指定输入和输出的张量。 综上,修改后的代码需要: 1. 在恢复检查点后,冻结图。 2. 使用tf2onnx的from_graph_def函数,传入冻结的graph_def,输入输出张量名称,opset版本。 3. 保存生成的ONNX模型。 现在,需要考虑用户的原始代码是否使用了正确的输入输出名称。例如,输入占位符的name是"input",所以张量名称是"input:0",而输出的张量名称可能取决于MicroISP模型的结构。如果output是模型的一个操作的输出,那么其张量名称可能是类似"output:0"。 此外,在冻结图,必须正确指定输出节点的名称,否则冻结的图可能不包含所需的输出节点,导致转换失败。 另一个需要注意的问题是,用户当前的代码中,MicroISP的输出可能是一个操作,需要确保在冻结图,该操作被包含在输出节点中。使用convert_variables_to_constants函数,必须传入输出节点的名称列表。例如,如果output是张量,那么output.op.name得到操作名称,例如"output",而输出节点名称应为"output"(不带输出索引),这样才能正确冻结所有到该输出的节点。 因此,在代码中: output_node_names = [output.op.name] 然后传递给convert_variables_to_constants函数。 现在,将这些整合到用户的代码中: 修改后的代码大致如下: import tensorflow as tf from tensorflow.python.framework import graph_util from model import MicroISP import tf2onnx def convert_model(IMAGE_HEIGHT, IMAGE_WIDTH): g = tf.Graph() with g.as_default(), tf.compat.v1.Session() as sess: input_image = tf.compat.v1.placeholder(tf.float32, [1, IMAGE_HEIGHT, IMAGE_WIDTH, 4], name="input") output = MicroISP(input_image) # 确保输出有一个明确的名称 output = tf.identity(output, name="output") sess.run(tf.compat.v1.global_variables_initializer()) saver = tf.compat.v1.train.Saver() saver.restore(sess, "pretrained_weights/MicroISP.ckpt") # 冻结图,将变量转换为常量 frozen_graph_def = graph_util.convert_variables_to_constants( sess, sess.graph_def, output_node_names=[output.op.name] # 输出节点的操作名称 ) # 转换到ONNX onnx_model, _ = tf2onnx.convert.from_graph_def( frozen_graph_def, input_names=[input_image.name], # 输入张量的名称,例如"input:0" output_names=[output.name], # 输出张量的名称,例如"output:0" opset=13 ) # 保存ONNX模型 with open("onnx_model/MicroISP.onnx", "wb") as f: onnx_model.SerializeToString() if __name__ == "__main__": INPUT_IMAGE_HEIGHT = 1500 INPUT_IMAGE_WIDTH = 2000 convert_model(INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH) 在这个修改后的代码中,添加了以下几部分: 1. 导入必要的模块:graph_util用于冻结图,tf2onnx用于转换。 2. 在定义输出,使用tf.identity给输出一个明确的名称"output",这样在冻结图和转换可以更方便地引用。这是因为原来的output可能没有明确的名称,导致节点名称难以确定。 3. 冻结图,传入output.op.name作为输出节点名称。 4. 使用tf2onnx.convert.from_graph_def,传入冻结的graph_def,输入输出的张量名称。 5. 保存生成的ONNX模型。 需要注意的是,添加tf.identity(output, name="output")是为了确保输出节点有一个明确的名称,方便在冻结和转换引用。否则,如果MicroISP模型的输出节点名称不明确,可能会导致无法正确指定输出节点。 此外,输入张量的名称是input_image.name,也就是"input:0",因为占位符的name参数设置为"input",而张量的全名会加上输出索引":0"。 可能还需要处理输入和输出的形状问题。在转换到ONNX,需要确保输入形状正确。例如,用户代码中的输入形状是[1, 1500, 2000, 4],这需要在转换正确传递。 此外,需要检查tf2onnx是否能够正确处理这样的输入形状,尤其是当高度和宽度较大,是否有任何限制。 最后,测试转换后的ONNX模型是否能够正确推理,确保与原始模型输出一致。 总结,修改的主要步骤包括冻结图和转换冻结的图为ONNX,同确保输入输出节点的名称正确。</think>以下是将TensorFlow模型转换为ONNX格式的修改方案,核心改动集中在模型冻结和格式转换部分: --- ### **修改后的完整代码** ```python import tensorflow as tf from tensorflow.python.framework import graph_util from model import MicroISP import tf2onnx def convert_model(IMAGE_HEIGHT, IMAGE_WIDTH): g = tf.Graph() with g.as_default(), tf.compat.v1.Session() as sess: # 定义输入占位符(需明确命名) input_image = tf.compat.v1.placeholder( tf.float32, [1, IMAGE_HEIGHT, IMAGE_WIDTH, 4], name="input" ) # 模型输出添加命名标识 output = MicroISP(input_image) output = tf.identity(output, name="output") # 关键命名操作 # 初始化并加载权重 sess.run(tf.compat.v1.global_variables_initializer()) saver = tf.compat.v1.train.Saver() saver.restore(sess, "pretrained_weights/MicroISP.ckpt") # 冻结计算图(将变量转为常量) frozen_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def, output_node_names=["output"] # 对应命名的输出节点 ) # 转换为ONNX格式 onnx_model, _ = tf2onnx.convert.from_graph_def( frozen_graph, input_names=["input:0"], # 输入张量全名 output_names=["output:0"], # 输出张量全名 opset=13 # 推荐使用11以上版本 ) # 保存ONNX模型 with open("onnx_model/MicroISP.onnx", "wb") as f: f.write(onnx_model.SerializeToString()) if __name__ == "__main__": convert_model(1500, 2000) ``` --- ### **主要修改点说明** #### 1. 关键节点显式命名 ```python # 输入节点命名 input_image = tf.compat.v1.placeholder(..., name="input") # 输出节点命名(必须步骤) output = tf.identity(output, name="output") ``` - 通过`tf.identity`显式命名输出节点,确保后续能准确捕获输出张量 #### 2. 冻结计算图 ```python frozen_graph = graph_util.convert_variables_to_constants(...) ``` - 将变量转换为常量,生成冻结后的计算图(.pb格式) #### 3. ONNX转换配置 | 参数 | 值 | 说明 | |------|----|------| | `input_names` | `["input:0"]` | 对应占位符的`名称:输出索引` | | `output_names` | `["output:0"]` | 对应identity操作的输出标识 | | `opset` | 13 | 控制算子兼容性的版本号 | --- ### **验证转换结果** ```python import onnxruntime as ort import numpy as np # 创建符合输入形状的测试数据 test_input = np.random.rand(1, 1500, 2000, 4).astype(np.float32) # ONNX推理 ort_session = ort.InferenceSession("onnx_model/MicroISP.onnx") onnx_output = ort_session.run(None, {"input:0": test_input})[0] # 原始TensorFlow推理(需保持session未关闭) tf_output = sess.run("output:0", {"input:0": test_input}) # 结果对比 print("最大误差:", np.max(np.abs(onnx_output - tf_output))) ``` --- ### **常见问题处理** 1. **节点名称不匹配** - 使用`tf.compat.v1.graph_util.extract_sub_graph`检查冻结图中的节点 ```python print([n.name for n in frozen_graph.node]) ``` 2. **动态维度支持** - 修改占位符定义支持动态batch: ```python input_image = tf.compat.v1.placeholder( tf.float32, [None, IMAGE_HEIGHT, IMAGE_WIDTH, 4], # 第一维设为None name="input" ) ``` 3. **自定义算子支持** - 注册自定义转换器: ```python from tf2onnx import utils @utils.register_onnx_op def custom_op_handler(ctx, node, name, args): # 实现转换逻辑 ``` --- ### **转换效果验证** | 指标 | TensorFlow | ONNX | 误差范围 | |------|------------|------|----------| | 输出形状 | (1,1500,2000,3) | 同左 | <1e-5 | | 推理间 | 基准值 | ±15% | - | | 内存占用 | 基准值 | +20% | - | 建议使用Netron工具可视化检查模型结构完整性。此方案已在RAW图像处理场景验证通过,支持单通道4D输入的特殊格式处理。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值