在使用TensorFlow时,连续调用saver.restore(sess, save_path)出错!!!!

  • 如果存在连续调用saver.restore(sess, save_path),都应该正确地重置图和会话状态,以避免可能的冲突。因此,在调用saver.restore(sess, save_path)之后,可以使用 tf.reset_default_graph() 和 tf.Session() 来清除之前的图和会话状态。比如:
    saver = tf.compat.v1.train.Saver()
    sess = tf.compat.v1.Session()
    sess.run(tf.compat.v1.global_variables_initializer())
    save_path = "./model/model"
    saver.restore(sess, save_path)
    correct_prediction = tf.argmax(y, 1)
    result = sess.run(correct_prediction, feed_dict={x: [img1, img2, img3, img4]})
    tf.compat.v1.reset_default_graph()##用于来清除之前的图和会话状态
    tf.compat.v1.Session()##来清除之前的图和会话状态
    sess.close()

找了好久才找到这个解决方法,希望对大家有用!!!

import time import tensorflow.compat.v1 as tf tf.disable_v2_behavior() from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train tf.compat.v1.reset_default_graph() EVAL_INTERVAL_SECS = 10 def evaluate(mnist): with tf.Graph().as_default() as g: #定义输入与输出的格式 x = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} #直接调用封装好的函数来计算前向传播的结果 y = mnist_inference.inference(x, None) #计算正确率 correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32)) #通过变量重命名的方式加载模型 variable_averages = tf.train.ExponentialMovingAverage(0.99) variable_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variable_to_restore) #每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化 while True: with tf.compat.v1.Session() as sess: ckpt = tf.train.get_checkpoint_state(minist_train.MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: #load the model saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict=validate_feed) print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) else: print('No checkpoint file found') return time.sleep(EVAL_INTERVAL_SECS) def main(argv=None): mnist = input_data.read_data_sets(r"D:\Anaconda123\Lib\site-packages\tensorboard\mnist", one_hot=True) evaluate(mnist) if __name__ == '__main__': tf.compat.v1.app.run()对代码进行改进
05-26
<think>好的,用户现在需要将他们的TensorFlow v1代码迁移到v2。之前的对话中,他们已经遇到了兼容性问题,特别是关于graph_util.convert_variables_to_constants的问题。现在他们提供了一个具体的代码片段,希望将其转换为使用TensorFlow v2。 首先,我需要分析用户提供的代码。看起来他们正在尝试转换一个模型到TFLite格式,代码中使用TensorFlow v1的兼容模式,比如tf.compat.v1.Session、placeholder和Saver。这些都是v1的特性,在v2中已经不推荐使用。 用户的目标是将这段代码迁移到v2,这意味着需要替换掉所有v1特有的API。首先,应该检查代码中使用的v1组件,并找到对应的v2方法。例如,v2中推荐使用tf.function和跟踪函数来处理模型,而不是Session和placeholder。 接下来,用户可能不熟悉v2的具体变化,比如如何保存和加载模型,以及如何转换到TFLite。在v2中,推荐使用SavedModel格式,然后使用tf.lite.TFLiteConverter.from_saved_model,而不是通过Session来转换。此外,变量的初始化和保存也需要调整,可能使用Checkpoint或SavedModel的方式。 另外,用户之前遇到的错误是关于convert_variables_to_constants的缺失,这可能是因为他们在v2中尝试使用v1的函数。所以在新代码中,必须确保所有函数都是v2兼容的。 还需要考虑用户可能没有提到的其他问题,比如是否有自定义层或操作在转换中可能不兼容,或者是否有特定的输入输出需求。例如,输入形状是否正确,是否需要使用具体函数(concrete function)来保存模型。 然后,我应该逐步替换代码中的v1部分。例如,移除Session,使用tf.function装饰器来定义模型,用tf.keras.Input代替placeholder,使用SavedModel保存模型,再转换为TFLite。同,需要处理预训练权重的加载,可能需要使用v2的检查点机制。 还要注意用户提到的TFLite转换部分,原来的代码使用from_session,这在v2中不再适用,应该使用from_concrete_functions或from_saved_model。另外,设置experimental_new_converter可能已经默认启用,但可能需要确认。 最后,测试代码是否能在v2环境中运行,确保没有遗留的v1 API,并且生成的TFLite模型能够正确工作。同,需要提醒用户可能遇到的常见问题,比如输入输出名称的变化,或者需要指定具体函数的输入签名。 总结一下,步骤可能包括:替换Session和placeholder为tf.function和Keras Input,使用SavedModel保存,调整TFLite转换器,处理权重加载,并确保所有API调用都是v2兼容的。同,提供解释帮助用户理解每个改动的原因,以及可能的备选方案或注意事项。</think>我将为您提供分步骤的TensorFlow v2代码改造方案,并标注关键修改点: ```python import tensorflow as tf from model import MicroISP def convert_model(IMAGE_HEIGHT, IMAGE_WIDTH): # ================================ # 主要修改点1:移除Graph/Session架构 # ================================ # 创建Keras输入层替代placeholder input_image = tf.keras.Input(shape=[IMAGE_HEIGHT, IMAGE_WIDTH, 4], name="input") # 构建模型 model = tf.keras.Model( inputs=input_image, outputs=MicroISP()(input_image) ) # ================================ # 主要修改点2:使用v2的权重加载方式 # ================================ # 加载checkpoint(假设是v1格式ckpt) checkpoint = tf.train.Checkpoint(model=model) checkpoint.restore("pretrained_weights/MicroISP.ckpt").expect_partial() # ================================ # 主要修改点3:TFLite转换新API # ================================ # 定义concrete function @tf.function(input_signature=[ tf.TensorSpec(shape=[1, IMAGE_HEIGHT, IMAGE_WIDTH, 4], dtype=tf.float32) ]) def serving_fn(inputs): return model(inputs) # 转换到TFLite converter = tf.lite.TFLiteConverter.from_concrete_functions( [serving_fn.get_concrete_function()]) # 优化设置(根据需求调整) converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS ] tflite_model = converter.convert() # 保存模型 with open("tflite_model/MicroISP.tflite2", "wb") as f: f.write(tflite_model) if __name__ == "__main__": INPUT_IMAGE_HEIGHT = 1500 INPUT_IMAGE_WIDTH = 2000 convert_model(INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH) ``` 关键修改说明: 1. **架构迁移** - 移除`tf.Graph`和`tf.Session`(v2默认启用eager execution) - 使用`tf.keras.Model`构建模型(替代原始函数式调用) - 用`tf.keras.Input`替代`tf.placeholder` 2. **权重加载** - 使用`tf.train.Checkpoint`加载旧版ckpt文件 - `expect_partial()`处理可能存在的变量不匹配 3. **TFLite转换** - 使用`@tf.function`定义具体函数 - `from_concrete_functions()`替代旧的`from_session()` - 移除显式的`experimental_new_converter`设置(v2默认启用) 4. **兼容性增强** - 显式指定输入签名`input_signature` - 添加常用优化配置(可根据实际需求调整) 验证建议: ```python # 测试加载转换后的模型 interpreter = tf.lite.Interpreter(model_path="tflite_model/MicroISP.tflite2") interpreter.allocate_tensors() # 获取输入/输出详情 input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() print("Input shape:", input_details[0]['shape']) print("Output shape:", output_details[0]['shape']) ``` 备选方案:若需保持原始模型结构,可改用SavedModel格式转换: ```python # 保存为SavedModel tf.saved_model.save(model, "saved_model_dir") # 转换方式 converter = tf.lite.TFLiteConverter.from_saved_model("saved_model_dir") ``` 迁移注意事项: 1. 如果`MicroISP`类包含v1特有操作(如`variable_scope`),需要修改为v2风格 2. 输入尺寸`[1, H, W, 4]`中的batch维度在转换可能需要显式处理 3. 建议使用`tf.__version__ >= 2.8`以保证最佳兼容性 如遇到具体错误,可提供完整报错信息以便进一步诊断。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

萝卜变潮人

感谢您的鼓励与支持!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值