tensorflow模型加载时,出现“Key Variable_xxx not found in checkpoint”的问题

本文解决了在使用TensorFlow加载模型时遇到的“KeyVariableNotFoundInCheckpoint”问题。通过调整tf.train.Saver中的var_list参数为tf.trainable_variables(),确保只有训练的变量被保存,避免了优化变量导致的加载错误。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

 

在使用tensorflow保存模型后,想要加载模型,用来对测试数据进行测试时,出现了“Key Variable not fonund in checkpoint"问题。

谷歌了一下,发现有一篇博客已经解答了这个问题,链接在参考博客上。为了看看我的问题和他的问题是不是一样的,我按照该博客的方法,在加载模型的代码中进行如下修改。

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

saver = tf.train.Saver()
reader = pywrap_tensorflow.NewCheckpointReader(model_checkpoint_path)#model_checkpoint_path是保存模型的路径加上模型名
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor name:",key)
    saver.restore(sess, model_checkpoint_path)

 输出的结果:

发现确实如这篇博客所说的,是因为模型把优化的变量也给保存下来了,所以后面就按照该博客建议的,在tf.train.Saver中指定真正要保存的变量。因为我的模型比较多层,按照这篇博客那种方法就太复杂了,所以我就直接把var_list设置成tf.trainable_variables()。

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

reader = pywrap_tensorflow.NewCheckpointReader(model_checkpoint_path)#model_checkpoint_path是保存模型的路径加上模型名
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor name:",key)
    saver.restore(sess, model_checkpoint_path)
saver = tf.train.Saver(max_to_keep=FLAGS.keep_ckpt + 1 if FLAGS.keep_ckpt else 1000000, var_list = tf.trainable_variables()) #前面的max_to_keep是确认要保存的模型的个数,这个如果不指定的话可以直接去掉,只保留后面的var_list = tf.trainable_variables()

修改完成后,再次运行训练代码,将训练完成的模型保存下来。保存完成后,再次加载模型。本来以为已经大功告成了,没想到再次报错。

再次把模型打印出来,跟之前加载的模型进行对比发现:有些tensor的tensorname已经没有Adam这个名称了,但是有些还是有。仔细观察我发现,那些没有的层,都是我训练的时候设置成可训练的层。因为我在训练的时候,是将pretrained的模型加载进来,然后对pretrained的模型的后面几层进行训练的。

发现这个问题之后,我再在测试代码中,插入下面这行代码。

saver = tf.train.Saver(max_to_keep=FLAGS.keep_ckpt + 1 if FLAGS.keep_ckpt else 1000000,var_list=tf.trainable_variables())

最后模型正常加载,可以正常进行测试了。

总结:

需要在训练和测试代码中都增加下面这段代码。

saver = tf.train.Saver(max_to_keep=FLAGS.keep_ckpt + 1 if FLAGS.keep_ckpt else 1000000,var_list=tf.trainable_variables())

参考博客

https://blog.youkuaiyun.com/qq_36810544/article/details/81456182

<think>好的,我现在需要解决在加载模型出现的变量缺失问题,具体是TensorFlow或PyTorch的检查点中缺少变量“Generator_1/decoder/conv2_5/bias”。首先,我得理解这个问题的可能原因,然后找到对应的解决方案。 首先,变量缺失可能发生在模型保存和加载的过程中。可能的原因有几个:模型结构在保存和加载不一致,变量名称发生了变化,或者模型权重未正确保存。我需要分别针对TensorFlow和PyTorch来考虑解决方案。 在PyTorch中,当保存模型,通常使用`torch.save()`。如果保存的是整个模型(包括结构),加载也需要相同的结构定义。如果保存的是状态字典(state_dict),则需要在加载前实例化相同的模型结构,然后调用`load_state_dict()`。如果变量名称不匹配,比如在加载模型层的名称与保存的不同,就会导致Key Error。例如,用户提到的Key是“Generator_1/decoder/conv2_5/bias”,可能在保存该层的名称不同,或者在加载模型结构有变化,导致该层不存在或者名称改变。 解决方法可能包括:检查模型结构是否一致,确保加载模型与保存模型在层名称和结构上完全匹配。如果模型结构有变化,可能需要手动调整状态字典,删除或重命名某些键。或者使用`strict=False`参数来允许部分加载,忽略缺失的键。例如,在PyTorch中,加载可以设置`model.load_state_dict(torch.load(path), strict=False)`,这样即使某些键缺失也不会报错,但需要注意这可能导致部分权重未被加载,影响模型性能。 对于TensorFlow,情况类似。在保存模型,如果使用`tf.train.Checkpoint`或`model.save()`,加载需要模型结构一致。如果变量名称不匹配,可能需要检查自定义层的命名是否正确,或者在加载进行变量映射。TensorFlow 2.x的Keras API在保存模型,默认会保存模型结构,但如果是自定义层或模型,可能需要指定正确的名称或在加载传入自定义对象。如果使用`load_weights()`方法加载权重到模型中,必须确保模型结构完全匹配,否则会出现变量找不到的错误。此,可以尝试在加载前重建模型结构,确保所有层的名称和结构与保存一致。 此外,用户提供的引用中提到了Keras的保存和恢复模型的方法,引用[3]提到Keras有两种构建模型的方式,保存模型/权重和恢复需要注意自定义层、损失函数等。如果在模型中使用了自定义层,如自定义的卷积层,加载模型需要提供这些自定义对象的定义,否则会报错。例如,在加载模型使用`custom_objects`参数来指定自定义层或函数。 现在需要综合这些信息,给出具体的解决步骤。可能包括: 1. 检查模型结构是否在保存和加载一致,特别是变量名称是否有变化。 2. 对于PyTorch,使用`strict=False`加载并检查缺失的键,或者调整模型结构。 3. 对于TensorFlow,确保自定义层正确注册,或使用变量重命名。 4. 检查保存的模型文件是否完整,是否有损坏。 5. 如果模型结构确实有变化,考虑迁移学习或部分加载可用权重。 例如,在PyTorch中,如果加载提示缺少某个键,可以先打印当前模型的状态字典键和保存的检查点中的键,比较差异,然后通过修改状态字典的键名来匹配。或者,如果该变量在模型中已不存在,可以忽略该键。 在TensorFlow中,可能需要使用`tf.train.list_variables`来检查检查点中的变量列表,确认是否存在该变量。如果不存在,可能是因为模型结构改变,需要重新训练或调整结构。如果是名称问题,可以使用`tf.train.Checkpoint`的`var_list`参数进行变量映射。 此外,引用[2]中提到了PyTorch保存模型的例子,使用`torch.save`保存模型,而加载需要对应结构。如果保存的是整个模型,而加载结构不同,就会出现问题。因此,保存状态字典而非整个模型可能更灵活,但需要确保结构一致。 总结解决方案步骤: 对于PyTorch: 1. 检查模型定义是否与保存一致,确保所有层的名称匹配。 2. 加载使用`strict=False`,并处理缺失的键。 3. 手动调整状态字典的键名以匹配当前模型结构。 4. 如果缺失的变量不再需要,可以忽略;否则需要修改模型结构。 对于TensorFlow: 1. 检查模型结构是否一致,特别是自定义层的名称。 2. 使用`custom_objects`参数加载包含自定义层的模型。 3. 使用变量映射或重命名来匹配检查点中的变量。 4. 检查检查点中的变量列表,确认是否存在该变量,可能需要重新训练或修复模型结构。 最后,用户需要根据具体使用的框架,选择相应的解决方案,并仔细检查模型结构和变量名称,确保一致。</think>### 解决模型加载变量缺失问题的分步指南 #### 问题原因分析 加载模型出现变量缺失错误(如`KeyError: 'Generator_1/decoder/conv2_5/bias'`),通常由以下原因导致: 1. **模型结构不一致**:加载模型与保存模型结构或层名称不匹配。 2. **变量命名变化**:模型层在保存后被重命名或删除。 3. **自定义组件未注册**(TensorFlow):加载包含自定义层/函数的模型未声明`custom_objects`。 4. **检查点损坏**:保存的模型文件不完整或损坏。 --- ### PyTorch 解决方案 #### 步骤1:检查模型结构一致性 确保加载代码中的模型定义与保存**完全一致**,包括: - 层名称 - 子模块顺序 - 参数配置(如卷积核数量) #### 步骤2:灵活加载状态字典 使用`strict=False`允许部分加载,忽略缺失的变量: ```python model = Generator() # 实例化模型 state_dict = torch.load('model.pth') model.load_state_dict(state_dict, strict=False) ``` - **输出警告**会提示缺失的键(如`conv2_5.bias`)和意外的键(如新增层的参数)。 - 若缺失的键对模型性能影响较小(如某些偏置项),可忽略该警告。 #### 步骤3:手动调整键名映射 若层名称因重构改变(如`conv2_5`重命名为`conv3_5`),可修改状态字典键名: ```python new_state_dict = {} for key, value in state_dict.items(): new_key = key.replace('conv2_5', 'conv3_5') # 替换旧键名 new_state_dict[new_key] = value model.load_state_dict(new_state_dict) ``` #### 步骤4:重新训练或冻结部分层 若缺失变量为关键参数: - 重新设计模型结构,恢复缺失层。 - 冻结无关层,仅训练新增部分: ```python for param in model.parameters(): param.requires_grad = False # 冻结所有层 new_layer = nn.Conv2d(...) # 新增缺失层 new_layer.weight.requires_grad = True # 允许训练新层 ``` --- ### TensorFlow 解决方案 #### 步骤1:检查检查点变量列表 使用`tf.train.list_variables`查看检查点内容: ```python import tensorflow as tf variables = tf.train.list_variables('path/to/checkpoint') for var_name, shape in variables: print(f"Variable: {var_name}, Shape: {shape}") ``` - 确认`Generator_1/decoder/conv2_5/bias`是否存在。 #### 步骤2:加载注册自定义对象 若模型包含自定义层/函数,需通过`custom_objects`参数声明: ```python model = tf.keras.models.load_model( 'model.h5', custom_objects={'CustomLayer': CustomLayer} # 替换为实际自定义组件名称 ) ``` #### 步骤3:变量名称映射 若变量名称因版本更新变化,可通过`var_list`重命名: ```python checkpoint = tf.train.Checkpoint( conv2_5_bias=model.decoder.layers[4].bias # 将检查点中的变量映射到新名称 ) checkpoint.restore('path/to/checkpoint') ``` #### 步骤4:重建模型结构 若检查点中确实缺失变量: 1. 修改模型定义,补充缺失层: ```python class Decoder(tf.keras.Model): def __init__(self): super().__init__() self.conv2_5 = tf.keras.layers.Conv2D(..., name='conv2_5') # 确保名称一致 ``` 2. 重新训练模型,或仅加载可用参数: ```python model.load_weights('checkpoint.h5', by_name=True) # 按名称加载匹配的权重 ``` --- ### 通用验证方法 1. **比较参数量**:打印加载前后模型的参数量,确认是否一致。 - PyTorch: ```python print(sum(p.numel() for p in model.parameters())) ``` - TensorFlow: ```python model.summary() ``` 2. **前向传播测试**:输入随机数据,验证模型能否正常推理。 --- ### 相关问题 1. **如何在PyTorch中仅保存模型权重?** 使用`torch.save(model.state_dict(), 'weights.pth')`,加载需先实例化模型再调用`load_state_dict`[^2]。 2. **TensorFlow中如何修复因自定义层导致的加载错误?** 在`tf.keras.models.load_model`中通过`custom_objects`参数注册自定义层类[^3]。 3. **模型加载后性能下降如何排查?** 检查权重加载完整性,使用一致性验证方法(如输出特定层的激活值对比)。 [^1]: 引用PyTorch模型保存实现示例。 : 参考PyTorch模型保存代码片段。 : 来自Keras模型保存与加载的官方文档说明。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值