废话不多说直接上我遇到的问题。我在训练自己修改的Mobile net_V2时遇到一个问题在笔记本上自己训练好自己的模型能通过简单的模型调用方式(https://blog.youkuaiyun.com/qq_38644840/article/details/96475356类似第二种方式调用)。是可以直接运行的,但是最近在看PBT训练方式,因此我将模型和自己的数据移植到别人的训练框架下运行保存得到自己的模型。由于是嵌入到作者写好的框架下运行的,显然命名方式和我们是不一样的,因此调用模型的时候出现节点找不到的问题!当然我们可以修改我们模型的变量的名字,从而保证训练的模型与实际模型之间可以调用,但这方法需要我们修改模型中的一些参数名,很容易遗漏。因此我们可以比较模型中和我们调用的模型的参数命名之间的区别。这里我们可以通过下面的代码进行比较:
#这里我们可以看看模型的参数的键名(字典中的key)
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader('./models/model32x32/model.ckpt')
var_to_shape_map = reader.get_variable_to_shape_map()
我们通过变量的看的更加直观:
但是我们得到的模型名字不同:
由于是同一个模型所以尾缀是一样的但是前面多了一个"model/"的路径实际上有的变量甚至多了两个这样的前缀,
但是还是存在规律的就是尾缀一致,我们直接运行:
import tensorflow as tf
x=tf.placeholder(tf.float32,shape=[None,64,64,3])
y=mobilenetv2(x,num_classes=11,is_train=False)
tf.reset_default_graph()
saver = tf.train.import_meta_graph("./models/checkpoint_300/model.ckpt-300.meta")
with tf.Session() as sess:
saver.restore(sess,'./models/checkpoint_300/model.ckpt-300')
显然会遇到下面的问题:
接下来我们要做的就是去掉那烦人的前缀了:
import tensorflow as tf
import argparse
import os
parser = argparse.ArgumentParser(description='')
parser.add_argument("--checkpoint_path", default='./checkpoint_300/model.ckpt-300', help="restore ckpt") #原参数路径
parser.add_argument("--new_checkpoint_path", default='./New_model/', help="path_for_new ckpt") #新参数保存路径
##parser.add_argument("--add_prefix", default='deeplab_v2/', help="prefix for addition") #新参数名称中加入的前缀名
#
args = parser.parse_args()
#去掉前缀,这里你可以定义自己的函数获取新的变量名
def removemodel(name):
if name[:6]=='model/':
name=name[6:]
name=removemodel(name)
return name
#b=removemodel(a)
def main():
if not os.path.exists(args.new_checkpoint_path):
os.makedirs(args.new_checkpoint_path)
with tf.Session() as sess:
new_var_list=[] #新建一个空列表存储更新后的Variable变量
for var_name, _ in tf.contrib.framework.list_variables(args.checkpoint_path): #得到checkpoint文件中所有的参数(名字,形状)元组
var = tf.contrib.framework.load_variable(args.checkpoint_path, var_name) #得到上述参数的值
new_name = removemodel(var_name)
# new_name = args.add_prefix + new_name #在这里加入了名称前缀,大家可以自由地作修改
#除了修改参数名称,还可以修改参数值(var)
print('Renaming %s to %s.' % (var_name, new_name))
renamed_var = tf.Variable(var, name=new_name) #使用加入前缀的新名称重新构造了参数
new_var_list.append(renamed_var) #把赋予新名称的参数加入空列表
print('starting to write new checkpoint !')
saver = tf.train.Saver(var_list=new_var_list) #构造一个保存器
sess.run(tf.global_variables_initializer()) #初始化一下参数(这一步必做)
model_name = 'mobilenet' #构造一个保存的模型名称
checkpoint_path = os.path.join(args.new_checkpoint_path, model_name) #构造一下保存路径
saver.save(sess, checkpoint_path) #直接进行保存
print("done !")
if __name__ == '__main__':
main()
这段代码就能帮你改变模型中的变量名了,我这里是放在保存的模型的文件夹里面运行的(读者可以根据自己的实际情况修改路径)。嗯嗯,然后我们调用我们新的模型就能正常运行了。
好了,说了一大堆废话,总结一下:如果遇到Notfounde ERROR可以尝试修改模型的变量名。而无需从新该模型命名甚至从新训练!
参考:https://blog.youkuaiyun.com/jiongnima/article/details/86632517
https://blog.youkuaiyun.com/runningwei/article/details/85677793