tensorflow的tf.train.Saver()函数使用小技巧

本文详细解析了TensorFlow中Saver类的作用及使用方法,包括模型保存、恢复和微调过程中的关键技巧。特别介绍了如何在加载预训练模型时排除特定层,以及在保存模型时包含所有变量的策略。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tensorflow的saver是很重要的,不光在保存模型文件的时候用到,在微调网络的过程中,加载预训练模型的时候也会用到;下面就一些实际工程中遇到的问题做一些讲解。

  • Saver类
def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None)
 var_list: 变量的列表,如果为None,则默认的变量为图中可保存的变量
 max_to_keep: 允许保存的最多的模型个数,当超过这个数值时,后面的模型会替换掉之前保存的模型
  • 经常用到的方法
saver = tf.train.Saver()

saver.save(sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True)     # 用于保存训练的模型
           
 saver.recover_last_checkpoints(checkpoint_paths)     # 用于从最近一次的训练结果恢复模型
 
 saver.restore(sess, save_path)      # 加载模型,可以指定加载某个模型,不一定非得最近一次
  • 微调网络过程中遇到的问题
  1. 网络只需加载一部分预训练模型的权重怎么办;或者说网络中某些层的权重,预训练模型中没有。
    解决办法:在定义saver对象的时候,把网络中这些层排除掉即可。然后在用restore从预训练模型中加载权重时就不会报错了,网络其余层没有从预训练模型加载权重的就需要初始化啦。
# 比如我网络中Logits层在预训练模型中没有
# 指定加载某些变量的权重
all_vars = tf.trainable_variables()
var_to_skip = [v for v in all_vars if v.name.startswith('Logits')]
print("got pretrained model, var_to_skip:\n" + " \n".join([x.name for x in var_to_skip]))
var_to_restore = [v for v in all_vars if not (v.name.startswith('Logits'))]
saver = tf.train.Saver(var_to_restore, max_to_keep=20)
sess.run(tf.global_variables_initializer())     # 初始化其余层的变量
saver.restore(sess, pretrained_model)           # 利用saver.restore恢复指定层的权重
  1. 保存是时候还要用之前定义的saver吗?
    我们之前定义的saver为了正确加载预训练模型,是把网络中以‘Logits’开头的变量排除了的;所以,如果还用这个saver来save训练模型的话,模型中会没有‘Logits’层的权重的。
    解决办法:重新再定义一个包含网络全部变量的saver对象用于保存模型,一个图中可以定义多个saver对象哟~
请查阅以下卷积神经网络代码,给出尝试运行并改良后的结果import tensorflow as tf import numpy as np from captcha.image import ImageCaptcha from PIL import Image import random import matplotlib.pyplot as plt number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] def random_captcha_text(char_set=number + alphabet + ALPHABET, captcha_size=4): captcha_text = [] for i in range(captcha_size): c = random.choice(char_set) captcha_text.append(c) return captcha_text def gen_captcha_text_and_image(i=0): # 创建图像实例对象 image = ImageCaptcha() # 随机选择4个字符 captcha_text = random_captcha_text() # array 转化为 string captcha_text = ''.join(captcha_text) # 生成验证码 captcha = image.generate(captcha_text) if i % 100 == 0: image.write(captcha_text, "D:/justin/captcha/image/" + captcha_text + '.jpg') captcha_image = Image.open(captcha) captcha_image = np.array(captcha_image) return captcha_text, captcha_image def convert2gray(img): if len(img.shape) > 2: gray = np.mean(img, -1) return gray else: return img # 文本转向量 def text2vec(text): text_len = len(text) if text_len > MAX_CAPTCHA: raise ValueError('验证码最长4个字符') vector = np.zeros(MAX_CAPTCHA * CHAR_SET_LEN) def char2pos(c): if c == '_': k = 62 return k k = ord(c) - 48 if k > 9: k = ord(c) - 55 if k > 35: k = ord(c) - 61 if k > 61: raise ValueError('No Map') return k for i, c in enumerate(text): idx = i * CHAR_SET_LEN + char2pos(c) vector[idx] = 1 return vector # 向量转回文本 def vec2text(vec): char_pos = vec[0] text = [] for i, c in enumerate(char_pos): char_idx = c % CHAR_SET_LEN if char_idx < 10: char_code = char_idx + ord('0') elif char_idx < 36: char_code = char_idx - 10 + ord('A') elif char_idx < 62: char_code = char_idx - 36 + ord('a') elif char_idx == 62: char_code = ord('_') else: raise ValueError('error') text.append(chr(char_code)) return "".join(text) # 生成一个训练batch def get_next_batch(batch_size=64): batch_x = np.zeros([batch_size, IMAGE_HEIGHT * IMAGE_WIDTH]) batch_y = np.zeros([batch_size, MAX_CAPTCHA * CHAR_SET_LEN]) def wrap_gen_captcha_text_and_image(i): while True: text, image = gen_captcha_text_and_image(i) if image.shape == (60, 160, 3): return text, image for i in range(batch_size): text, image = wrap_gen_captcha_text_and_image(i) image = convert2gray(image) batch_x[i, :] = image.flatten() / 255 batch_y[i, :] = text2vec(text) return batch_x, batch_y # 定义CNN def crack_captcha_cnn(w_alpha=0.01, b_alpha=0.1): x = tf.reshape(X, shape=[-1, IMAGE_HEIGHT, IMAGE_WIDTH, 1]) w_c1 = tf.Variable(w_alpha * tf.random_normal([3, 3, 1, 32])) b_c1 = tf.Variable(b_alpha * tf.random_normal([32])) conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, w_c1, strides=[1, 1, 1, 1], padding='SAME'), b_c1)) conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') conv1 = tf.nn.dropout(conv1, rate=1 - keep_prob) w_c2 = tf.Variable(w_alpha * tf.random_normal([3, 3, 32, 64])) b_c2 = tf.Variable(b_alpha * tf.random_normal([64])) conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, w_c2, strides=[1, 1, 1, 1], padding='SAME'), b_c2)) conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') conv2 = tf.nn.dropout(conv2, rate=1 - keep_prob) w_c3 = tf.Variable(w_alpha * tf.random_normal([3, 3, 64, 64])) b_c3 = tf.Variable(b_alpha * tf.random_normal([64])) conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, w_c3, strides=[1, 1, 1, 1], padding='SAME'), b_c3)) conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') conv3 = tf.nn.dropout(conv3, rate=1 - keep_prob) w_d = tf.Variable(w_alpha * tf.random_normal([8 * 20 * 64, 1024])) b_d = tf.Variable(b_alpha * tf.random_normal([1024])) dense = tf.reshape(conv3, [-1, w_d.get_shape().as_list()[0]]) dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d)) dense = tf.nn.dropout(dense, rate=1 - keep_prob) w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN])) b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN])) out = tf.add(tf.matmul(dense, w_out), b_out) return out # 训练 def train_crack_captcha_cnn(): output = crack_captcha_cnn() loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=output, labels=Y)) # 计算损失 optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) # 计算梯度 predict = tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN]) # 目标预测 max_idx_p = tf.argmax(predict, 2) # 目标预测最大值 max_idx_l = tf.argmax(tf.reshape(Y, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2) # 真实标签最大值 correct_pred = tf.equal(max_idx_p, max_idx_l) accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) # 准确率 saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) step = 0 while True: batch_x, batch_y = get_next_batch(64) _, loss_ = sess.run([optimizer, loss], feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.75}) print(step, loss_) if step % 100 == 0: batch_x_test, batch_y_test = get_next_batch(100) acc = sess.run(accuracy, feed_dict={X: batch_x_test, Y: batch_y_test, keep_prob: 1.}) print(step, "准确率:",acc) if acc > 0.85: saver.save(sess, "D:/justin/captcha/model/85", global_step=step) step += 1 def crack_captcha(captcha_image, output): saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.initialize_all_variables()) # 获取训练后的参数 checkpoint = tf.train.get_checkpoint_state("model") if checkpoint and checkpoint.model_checkpoint_path: saver.restore(sess, checkpoint.model_checkpoint_path) print("Successfully loaded:", checkpoint.model_checkpoint_path) else: print("Could not find old network weights") predict = tf.argmax(tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2) text_list = sess.run(predict, feed_dict={X: [captcha_image], keep_prob: 1}) text = vec2text(text_list) return text if __name__ == '__main__': train = 1 # 0: 训练 1: 预测 if train == 0: number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] text, image = gen_captcha_text_and_image() print("验证码图像channel:", image.shape) # 图像大小 IMAGE_HEIGHT = 60 IMAGE_WIDTH = 160 MAX_CAPTCHA = len(text) print("验证码文本最长字符数", MAX_CAPTCHA) # 文本转向量 char_set = number + alphabet + ALPHABET + ['_'] # 如果验证码长度小于4, '_'用来补齐 CHAR_SET_LEN = len(char_set) # placeholder占位符 X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH]) Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN]) keep_prob = tf.placeholder(tf.float32) train_crack_captcha_cnn() # 预测时需要将训练的变量初始化 if train == 1: # 自然计数 step = 0 # 正确预测计数 rightCnt = 0 # 设置测试次数 count = 10 number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] IMAGE_HEIGHT = 60 IMAGE_WIDTH = 160 char_set = number + alphabet + ALPHABET + ['_'] CHAR_SET_LEN = len(char_set) MAX_CAPTCHA = 4 # placeholder占位符 X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH]) Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN]) keep_prob = tf.placeholder(tf.float32) output = crack_captcha_cnn() saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 获取训练后参数路径 checkpoint = tf.train.get_checkpoint_state("model") if checkpoint and checkpoint.model_checkpoint_path: saver.restore(sess, checkpoint.model_checkpoint_path) print("Successfully loaded:", checkpoint.model_checkpoint_path) else: print("Could not find old network weights.") while True: text, image = gen_captcha_text_and_image() f = plt.figure() ax = f.add_subplot(111) ax.text(0.1, 0.9,text, ha='center', va='center', transform=ax.transAxes) plt.imshow(image) plt.show() image = convert2gray(image) image = image.flatten() / 255 predict = tf.math.argmax(tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2) text_list = sess.run(predict, feed_dict={X: [image], keep_prob: 1}) predict_text = vec2text(text_list) predict_text = crack_captcha(image, output) print("step:{} 真实值: {} 预测: {} 预测结果: {}".format(str(step), text, predict_text, "正确" if text.lower() == predict_text.lower() else "错误")) if text.lower() == predict_text.lower(): rightCnt += 1 if step == count - 1: print("测试总数: {} 测试准确率: {}".format(str(count), str(rightCnt / count))) break step += 1
最新发布
06-20
### 回答1: 你好,我理解你是想通过 优快云 学习 C 知识,并且你也使用 ChatGPT 这种类型的 AI。请问你有什么问题需要我回答吗? 关于你提到的问题,我可以解释一下。tf.train.Saver()TensorFlow 中用于保存和恢复模型的类。它允许你将 TensorFlow 变量保存到磁盘,并从这些变量中恢复。通常在训练过程中,你需要定期保存模型以便稍后恢复或进行推断。tf.train.Saver() 的常见用法如下: ```python import tensorflow as tf # 声明 TensorFlow 变量 weights = tf.Variable(...) biases = tf.Variable(...) # 初始化 TensorFlow 变量 init_op = tf.global_variables_initializer() # 训练模型 # 创建 Saver 对象 saver = tf.train.Saver() # 保存模型 with tf.Session() as sess: sess.run(init_op) # 训练模型 saver.save(sess, 'my_model') # 保存模型到文件 'my_model' ``` 当你需要恢复模型时,可以使用 tf.train.Saver()restore() 方法,如下所示: ```python # 创建 Saver 对象 saver = tf.train.Saver() # 加载模型 with tf.Session() as sess: saver.restore(sess, 'my_model') # 从文件 'my_model' 恢复模型 # 执行推断 ``` 希望这可以回答你的问题。如果你有其他问题,请随时提出。 ### 回答2: tf.train.Saver()是一个用于保存、加载以及管理模型参数的类。在TensorFlow中,我们通常使用tf.train.Saver()来保存和恢复模型的变量。 在TensorFlow中,模型的训练过程通常会生成一些变量,比如神经网络的权重和偏差。而tf.train.Saver()类提供了一种方法,可以将这些变量保存到文件中。通过调用tf.train.Saver().save()方法,可以将模型的变量保存在一个checkpoint文件中,以供将来使用。 除了保存模型变量,tf.train.Saver()还可以用于加载已保存的模型变量。通过调用tf.train.Saver().restore()方法,可以从checkpoint文件中载入模型的变量,并且将其赋值给指定的TensorFlow变量。这样,我们就可以在程序中使用这些已保存的模型变量,而无需重新训练模型。 另外,tf.train.Saver()还具备一些其他的功能,比如可以指定保存和加载的变量以及保存和恢复模型的过程是否应该包含模型的图结构。 总结起来,tf.train.Saver()是一个用于保存、加载和管理TensorFlow模型参数的类。它提供了保存和恢复模型变量的功能,可以确保模型的训练结果可以方便地在之后的使用中进行加载和重用。 ### 回答3: tf.train.Saver()tensorflow中用于模型参数的保存和恢复的类。 在tensorflow中,模型参数通常是在训练过程中不断更新的,而为了保留训练过程中的模型参数,我们可以使用tf.train.Saver()类来保存这些参数。tf.train.Saver()类提供了保存和恢复模型的方法,可以将模型的参数保存到文件中,并在需要的时候恢复这些参数。 保存模型参数是通过调用tf.train.Saver()类的save()方法实现的。save()方法需要传入一个session和一个保存路径,表示将当前模型的参数保存到指定的路径下。保存的参数可以是全局变量、权重、偏置等等。 恢复模型参数是通过调用tf.train.Saver()类的restore()方法实现的。restore()方法需要传入一个session和一个保存路径,表示从指定的路径中恢复模型的参数。恢复参数时,tensorflow会自动判断模型的参数是否与当前模型的参数匹配,如果匹配,则恢复参数;如果不匹配,则会抛出异常。 使用tf.train.Saver()类可以实现模型的断点续训。即在训练过程中,可以将当前的模型参数保存到文件中。如果训练过程中发生意外,可以在恢复训练时,加载之前保存的模型参数,从上一次中断的地方继续训练。 总之,tf.train.Saver()tensorflow中用于保存和恢复模型参数的重要工具,它提供了方便的接口,使得我们可以灵活地管理模型参数,实现模型的保存、恢复和断点续训。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值