Tensorflow Saver & restore 以及报错问题 NotFoundError: "x_x" not found in checkpoint

本文介绍了使用莫烦Python教程中的方法保存与加载TensorFlow模型参数的过程及遇到的问题。主要探讨了因变量名称冲突导致的加载失败错误,并提出了两种解决办法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

根据 莫烦Python 的教程 保存模型参数

# -*- coding: utf-8 -*-
"""
Created on Tue Mar 20 10:01:04 2018

@author: lyh
"""

import tensorflow as tf 
import numpy as np
#save

W = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name="weight")

init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess,"mynet/save_net.ckpt")
    print("Save to path:" + save_path)


#load
W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="weight")
#不需要init
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess,save_path)
    print("weight:" + sess.run(W))

然后出现以下错误:NotFoundError: Tensor name “weight_1” not found in checkpoint files mynet/save_net.ckpt
这里写图片描述

浏览网上相关讨论,尝试 restart kernel (Spyder 编辑器)以及 更改 tf.train.Saver(write_version=tf.train.SaverDef.V1) 为 V1 版本均无法正常 restore。

原因
真正的原因是,我写的代码 保存和加载 在前后进行,在前后两次定义了

W = tf.Variable(xxx,name="weight")

相当于 在TensorFlow 图的堆栈创建了两次 name = “weight” 的变量,第二个(第n个)的实际 name 会变成 “weight_1” (“weight_n-1”),之后我们在保存 checkpoint 中实际搜索的是 “weight_n-1” 这个变量 而不是 “weight” ,因此就会出错。

解决方案:
(1)在加载过程中,定义 name 相同的变量前面加
tf.reset_default_graph() 清除默认图的堆栈,并设置全局图为默认图
(2)正常场景下,不会保存模型之后,马上加载(或在同一程序中加载),就不会出现这个情况,或者保存完之后 restart kernel (Spyder 中),再进行参数加载。

<think>嗯,用户让我详细解释这段代码:def load(self, saver, ckpt_dir): ... 看起来这是一个用于加载TensorFlow模型检查点的方法。首先,我需要理解代码的每一部分功能。 代码里首先用tf.train.get_checkpoint_state获取检查点状态,这可能用来确认ckpt_dir中是否有有效的检查点。然后检查ckpt和model_checkpoint_path是否存在,如果存在的话,获取最新的检查点路径。接着尝试从路径中提取global_step,这里用split分割路径,可能会遇到异常,所以用了try-except。之后用saver.restore恢复模型,返回True和global_step。否则打印失败信息,返回False和0。 用户可能想了解这段代码的作用、每个步骤的意义,以及潜在的问题。需要注意用户之前问过保存模型,现在转向加载,可能是在模型训练和部署中遇到问题,比如加载失败或者步骤数不正确。 需要解释tf.train.get_checkpoint_state的作用,它读取checkpoint文件,里面记录了最新的检查点路径。然后tf.train.latest_checkpoint是更直接的方法获取最新检查点,但用户代码里先获取ckpt.model_checkpoint_path,再调用latest_checkpoint,可能有些重复,或者是为了兼容性? 另外,提取global_step的方式是否正确?比如路径可能是“model-1000”,分割后得到1000,但如果路径中有其他数字,可能会出错,导致global_step为None,这里需要指出潜在问题。 还有异常处理部分,捕获ValueError,但除了打印失败信息,返回的global_step是None还是0?用户代码里返回的是0,这可能有问题,因为如果无法解析,应该如何处理? 最后,saver.restore的正确使用,是否需要确保saver和模型结构匹配,比如之前保存时用了正确的变量作用域,否则加载会失败。用户之前的问题提到了变量作用域的问题,这里可能需要关联起来,指出加载时也需要同样的作用域结构。 可能用户在实际使用中遇到了加载失败的情况,或者global_step不正确,导致训练重启时步骤数不对。需要详细解释代码的每个部分,指出潜在风险,并给出改进建议,比如使用更可靠的方法提取global_step,或者检查检查点文件是否存在。 总结,我需要分步骤解释代码功能,指出可能的问题,比如路径处理、异常处理、global_step提取方法,以及确保saver正确初始化的建议。同时,结合之前的问题,强调作用域一致性的重要性,避免变量不匹配。</think>--- ### 🔍 逐行代码解析(TensorFlow 1.x 模型加载逻辑) #### ⚙️ **方法定义** ```python def load(self, saver, ckpt_dir): ``` - **功能**:定义模型加载方法 - **参数**: - `saver`:`tf.train.Saver` 实例(需与保存时创建的 Saver 匹配) - `ckpt_dir`:检查点文件存储目录(如 `"model_ckpt/"`) --- #### 🕵️ **检查点状态探测** ```python ckpt = tf.train.get_checkpoint_state(ckpt_dir) ``` - **作用**:读取检查点目录中的 `checkpoint` 文件(该文件记录最新检查点路径) - **典型输出**: ```python model_checkpoint_path: "my_model-1000" all_model_checkpoint_paths: "my_model-1000" ``` --- #### 🚦 检查有效性 ```python if ckpt and ckpt.model_checkpoint_path: ``` - **逻辑**:双重验证检查点是否存在且路径有效 - **常见问题**:若目录中没有实际检查点文件,`ckpt.model_checkpoint_path` 可能为 `None` --- #### 🔍 获取最新检查点路径 ```python full_path = tf.train.latest_checkpoint(ckpt_dir) ``` - **替代方案**:可直接使用 `ckpt.model_checkpoint_path`(更推荐直接使用 `latest_checkpoint`) --- #### 🔢 解析训练步数 ```python try: global_step = int(full_path.split('/')[-1].split('-')[-1]) except ValueError: global_step = None ``` - **路径示例**:`"model_ckpt/my_model-1000"` ➔ `global_step = 1000` - **风险**: - 若路径包含其他数字(如 `model_v2-1000`),可能解析错误 - 更可靠的方法:在保存时显式传递 `global_step` 参数 --- #### 🎯 模型恢复核心操作 ```python saver.restore(self.sess, full_path) ``` - **关键点**: 1. **Saver 匹配**:需与保存时的 `tf.train.Saver()` 配置一致(变量列表/作用域相同) 2. **Session 一致性**:`self.sess` 必须包含与保存时相同的计算图结构 --- #### 🚨 错误处理 ```python else: print("[*] Failed to load model from %s" % ckpt_dir) return False, 0 ``` - **常见失败原因**: - 检查点目录路径错误 - 检查点文件被手动删除 - 计算图结构变化导致变量不匹配 --- ### 🛠️ **代码优化建议** #### 1. 全局步数解析改进 ```python # 使用正则表达式提取步数(更健壮) import re match = re.search(r'(\d+)$', full_path) # 匹配路径末尾的数字 global_step = int(match.group(1)) if match else None ``` #### 2. 添加图结构校验 ```python # 加载前校验计算图与检查点是否匹配 from tensorflow.python.tools import inspect_checkpoint as chkp chkp.print_tensors_in_checkpoint_file(full_path, tensor_name='', all_tensors=False) ``` #### 3. 异常捕获增强 ```python try: saver.restore(self.sess, full_path) except tf.errors.NotFoundError as e: print(f"变量不匹配错误: {e}") return False, 0 ``` --- ### 🌰 **完整调用示例** ```python # 创建 Saver 实例(必须与保存时的变量作用域一致) saver = tf.train.Saver() # 创建 Session 并加载 with tf.Session() as sess: loader = ModelLoader(sess) # 假设 self 是某个类的实例 success, step = loader.load(saver, "model_ckpt/") if success: print(f"成功加载模型,训练步数: {step}") ``` --- ### 📌 **关键要点总结** | 代码段 | 核心作用 | 常见陷阱 | |-----------------------|---------------------------------|--------------------------------| | `get_checkpoint_state` | 验证检查点目录有效性 | 目录中缺少 `checkpoint` 文件 | | `latest_checkpoint` | 获取最新检查点路径 | 路径格式不规范导致解析失败 | | `saver.restore` | 恢复模型参数 | 计算图结构变化导致变量不匹配 | --- 若遇到加载失败,建议优先检查: 1. **检查点文件完整性**:确认 `.index`, `.data`, `.meta` 文件同时存在 2. **作用域一致性**:加载时模型构建代码需与保存时保持相同变量作用域结构 3. **TensorFlow 版本**:1.x/2.x 的模型格式不兼容
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值