机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用

1. tf.train.Saver()

  • tf.train.Saver()是一个类,提供了变量、模型(也称图Graph)的保存和恢复模型方法。
  • TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积、池化等)都需要operator,保存和恢复操作也不例外。
  • 在tf.train.Saver()类初始化时,用于保存和恢复的save和restore operator会被加入Graph。所以,下列类初始化操作应在搭建Graph时完成。
saver = tf.train.Saver()

 

TensorFlow的保存和恢复分为两种:

  • 保存和恢复变量
  • 保存和恢复模型


saver.save()保存模型

#举例:
保存一个训练好的手写数据集识别模型
保存在当前路径的net文件夹中

 1 import os
 2 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 3 import tensorflow as tf
 4 from tensorflow.examples.tutorials.mnist import input_data
 5 
 6 #载入数据集
 7 mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
 8 
 9 #每个批次100张照片
10 batch_size = 100
11 #计算一个需要多少个批次
12 n_batch = mnist.train.num_examples // batch_size
13 
14 #定义两个placeholder
15 x = tf.placeholder(tf.float32, [None, 784])
16 y = tf.placeholder(tf.float32, [None, 10])
17 
18 #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
19 W = tf.Variable(tf.zeros([784, 10]))
20 b = tf.Variable(tf.zeros([10]))
21 prediction = tf.nn.softmax(tf.matmul(x, W) + b)
22 #代价函数
23 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
24 #使用梯度下降法
25 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
26 
27 #初始化变量
28 init = tf.global_variables_initializer()
29 
30 #结果存放在一个布尔型列表中
31 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
32 
33 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
34 
35 saver = tf.train.Saver()
36 
37 with tf.Session() as sess:
38     sess.run(init)
39     for epoch in range(11):
40         for batch in range(n_batch):
41             batch_xs, batch_ys = mnist.train.next_batch(batch_size)
42             sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
43         acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
44         print('Iter = ' + str(epoch) +', Testing Accuracy = ' + str(acc))
45     #保存模型
46     saver.save(sess, 'net/my_net.ckpt')
View Code
#保存路径中的文件为:
    checkpoint:保存当前网络状态的文件
    my_net.ckpt.data-00000-of-00001
    my_net.ckpt.index
    my_net.ckpt.meta:保存Graph结构的文件

#关于函数saver.save(),常用的参数就是前三个:
save(
    sess,  # 必需参数,Session对象
    save_path,  # 必需参数,存储路径
    global_step=None,  # 可以是Tensor, Tensor name, 整型数
    latest_filename=None,  # 协议缓冲文件名,默认为'checkpoint',不用管
    meta_graph_suffix='meta',  # 图文件的后缀,默认为'.meta',不用管
    write_meta_graph=True,  # 是否保存Graph
    write_state=True,  # 建议选择默认值True
    strip_default_attrs=False  # 是否跳过具有默认值的节点

 

saver.restore()加载已经训练好的模型

#举例:
通过加载刚才保存的训练好的手写数据集识别模型进行手写数据集的识别

 1 import os
 2 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 3 import tensorflow as tf
 4 from tensorflow.examples.tutorials.mnist import input_data
 5 
 6 mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
 7 batch_size = 100
 8 n_batch = mnist.train.num_examples // batch_size
 9 
10 x = tf.placeholder(tf.float32, [None, 784])
11 y = tf.placeholder(tf.float32, [None, 10])
12 
13 W = tf.Variable(tf.zeros([784, 10]))
14 b = tf.Variable(tf.zeros([10]))
15 prediction = tf.nn.softmax(tf.matmul(x, W) + b)
16 
17 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
18 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
19 
20 init = tf.global_variables_initializer()
21 
22 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
23 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
24 
25 saver = tf.train.Saver()
26 
27 with tf.Session() as sess:
28     sess.run(init)
29     print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))
30     saver.restore(sess, 'net/my_net.ckpt')
31     print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))
View Code
#执行结果:
    
0.098
0.9178

#直接得到的准确率相当低,通过加载训练好的模型,识别准确率大大提升。

2. 下载google图像识别网络inception-v3并查看结构

模型背景:
  Inception(v3) 模型是Google 训练好的最新一个图像识别模型,我们可以利用它来对我们的图像进行识别。

 

下载地址:
  https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip

 

文件描述:

  • classify_image_graph_def.pb 文件就是训练好的Inception-v3模型。
  • imagenet_synset_to_human_label_map.txt是类别文件,包含人类标签和uid之间的映射的文件。
  • imagenet_2012_challenge_label_map_proto.pbtxt是包含类号和uid之间的映射的文件。

代码实现

 1 import os
 2 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 3 import tensorflow as tf
 4 import tarfile
 5 import requests
 6 
 7 #inception模型下载地址
 8 inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
 9 
10 #inception模型存放地址
11 inception_pretrain_model_dir = 'inception_model'
12 if not os.path.exists(inception_pretrain_model_dir):
13     os.makedirs(inception_pretrain_model_dir)
14 #获取文件名,以及文件路径
15 filename = inception_pretrain_model_url.split('/')[-1]
16 filepath = os.path.join(inception_pretrain_model_dir, filename)
17 
18 #下载模型
19 if not os.path.exists(filepath):
20     print('download: ', filename)
21     r = requests.get(inception_pretrain_model_url, stream=True)
22     with open(filepath, 'wb') as f:
23         for chunk in r.iter_content(chunk_size=1024):
24             if chunk:
25                 f.write(chunk)
26     print('finish: ', filename)
27 #解压文件
28 tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir)
29 
30 #模型结构存放文件
31 log_dir = 'inception_log'
32 if not os.path.exists(log_dir):
33     os.makedirs(log_dir)
34 
35 #classify_image_graph_def.pb为google训练好的模型
36 inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb')
37 with tf.Session() as sess:
38     #创建一个图来存放google训练好的模型
39     with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f:
40         graph_def = tf.GraphDef()
41         graph_def.ParseFromString(f.read())
42         tf.import_graph_def(graph_def, name='')
43     #保存图的结构
44     writer = tf.summary.FileWriter(log_dir, sess.graph)
45     writer.close()
View Code
#在下载过程中,下的特别慢,不知道是网络原因还是什么
#程序总卡着不动
#所以我就手动下载压缩包并进行解压

下载结果

3. 使用inception-v3做各种图像的识别

#代码实现:

 1 import os
 2 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 3 import tensorflow as tf
 4 import numpy as np
 5 import re
 6 from PIL import Image
 7 import matplotlib.pyplot as plt
 8 
 9 #这部分是对标签号和类别号文件进行一个预处理
10 
11 class NodeLookup(object):
12     def __init__(self):
13         label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'
14         uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt'
15         self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
16     def load(self, label_lookup_path, uid_lookup_path):
17         #加载分类字符串n********对应分类名称的文件
18         proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
19         uid_to_human={}
20         #一行一行读取数据
21         for line in proto_as_ascii_lines:
22             #去掉换行符
23             line = line.strip('\n')
24             #按照‘\t’进行分割
25             parsed_items = line.split('\t')
26             #获取分类编号
27             uid = parsed_items[0]
28             #获取分类名称
29             human_string = parsed_items[1]
30             #保存编号字符串n********与分类名称的映射关系
31             uid_to_human[uid] = human_string
32 
33         #加载分类字符串n********对应分类编号1-1000的文件
34         proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
35         node_id_to_uid = {}
36         for line in proto_as_ascii:
37             if line.startswith('  target_class:'):
38                 #获取分类编号1-1000
39                 target_class = int(line.split(': ')[1])
40             if line.startswith('  target_class_string:'):
41                 #获取编号字符串nn********
42                 target_class_string = line.split(': ')[1]
43                 # 保存分类编号1-1000与编号字符串n********映射关系
44                 node_id_to_uid[target_class] = target_class_string[1:-2]
45         # 建立分类编号1-1000对应分类名称的映射关系
46         node_id_to_name = {}
47         for key, val in node_id_to_uid.items():
48             #获取分类名称
49             name = uid_to_human[val]
50             # 建立分类编号1-1000到分类名称的映射关系
51             node_id_to_name[key] = name
52         return node_id_to_name
53     # 传入分类编号1-1000返回分类名称
54     def id_to_string(self, node_id):
55         if node_id not in self.node_lookup:
56             return ''
57         return self.node_lookup[node_id]
58 
59 #创建一个图来存放google训练好的模型
60 
61 with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb', 'rb') as f:
62     graph_def = tf.GraphDef()
63     graph_def.ParseFromString(f.read())
64     tf.import_graph_def(graph_def, name='')
65 
66 with tf.Session() as sess:
67     softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
68     #遍历目录
69     for root, dirs, files in os.walk('images/'):
70         for file in files:
71             #载入图片
72             image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read()
73             predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
74             predictions = np.squeeze(predictions)#把结果转为1维数据
75 
76             #打印图片路径及名称
77             image_path = os.path.join(root, file)
78             print(image_path)
79 
80             # 显示图片
81             img = Image.open(image_path)
82             plt.imshow(img)
83             plt.axis('off')
84             plt.show()
85 
86             #排序
87             top_k = predictions.argsort()[-5:][::-1]
88             node_lookup = NodeLookup()
89             for node_id in top_k:
90                 # 获取分类名称
91                 human_string = node_lookup.id_to_string(node_id)
92                 # 获取该分类的置信度
93                 score = predictions[node_id]
94                 print('%s(score = %.5f)' % (human_string, score))
95             print()
View Code


#执行结果:

images/1.jpg
giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca(score = 0.87265)
badger(score = 0.00260)
lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens(score = 0.00205)
brown bear, bruin, Ursus arctos(score = 0.00102)
ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus(score = 0.00099)

images/2.jpg
French bulldog(score = 0.94474)
bull mastiff(score = 0.00559)
pug, pug-dog(score = 0.00352)
Staffordshire bullterrier, Staffordshire bull terrier(score = 0.00165)
boxer(score = 0.00116)

images/3.jpg
zebra(score = 0.94011)
tiger, Panthera tigris(score = 0.00080)
pencil box, pencil case(score = 0.00066)
hartebeest(score = 0.00059)
tiger cat(score = 0.00042)

images/4.jpg
hare(score = 0.87019)
wood rabbit, cottontail, cottontail rabbit(score = 0.04802)
Angora, Angora rabbit(score = 0.00612)
wallaby, brush kangaroo(score = 0.00181)
fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.00056)

images/5.jpg
fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.95047)
marmot(score = 0.00265)
mongoose(score = 0.00217)
weasel(score = 0.00201)
mink(score = 0.00199)

 

转载于:https://www.cnblogs.com/guoruxin/p/10238018.html

### 回答1: 你好,我理解你是想通过 优快云 学习 C 知识,并且你也使用 ChatGPT 这种类型的 AI。请问你有什么问题需要我回答吗? 关于你提到的问题,我可以解释一下。tf.train.Saver()TensorFlow 中用于保存和恢复模型的类。它允许你将 TensorFlow 变量保存到磁盘,并从这些变量中恢复。通常在训练过程中,你需要定期保存模型以便稍后恢复或进行推断。tf.train.Saver() 的常见用法如下: ```python import tensorflow as tf # 声明 TensorFlow 变量 weights = tf.Variable(...) biases = tf.Variable(...) # 初始化 TensorFlow 变量 init_op = tf.global_variables_initializer() # 训练模型 # 创建 Saver 对象 saver = tf.train.Saver() # 保存模型 with tf.Session() as sess: sess.run(init_op) # 训练模型 saver.save(sess, 'my_model') # 保存模型到文件 'my_model' ``` 当你需要恢复模型时,可以使用 tf.train.Saver() 的 restore() 方法,如下所示: ```python # 创建 Saver 对象 saver = tf.train.Saver() # 加载模型 with tf.Session() as sess: saver.restore(sess, 'my_model') # 从文件 'my_model' 恢复模型 # 执行推断 ``` 希望这可以回答你的问题。如果你有其他问题,请随时提出。 ### 回答2: tf.train.Saver()是一个用于保存、加载以及管理模型参数的类。在TensorFlow中,我们通常使用tf.train.Saver()来保存和恢复模型的变量。 在TensorFlow中,模型的训练过程通常会生成一些变量,比如神经网络的权重和偏差。而tf.train.Saver()类提供了一种方法,可以将这些变量保存到文件中。通过调用tf.train.Saver().save()方法,可以将模型的变量保存在一个checkpoint文件中,以供将来使用。 除了保存模型变量,tf.train.Saver()还可以用于加载已保存的模型变量。通过调用tf.train.Saver().restore()方法,可以从checkpoint文件中载入模型的变量,并且将其赋值给指定的TensorFlow变量。这样,我们就可以在程序中使用这些已保存的模型变量,而无需重新训练模型。 另外,tf.train.Saver()还具备一些其他的功能,比如可以指定保存和加载的变量以及保存和恢复模型的过程是否应该包含模型的图结构。 总结起来,tf.train.Saver()是一个用于保存、加载和管理TensorFlow模型参数的类。它提供了保存和恢复模型变量的功能,可以确保模型的训练结果可以方便地在之后的使用中进行加载和重用。 ### 回答3tf.train.Saver()tensorflow中用于模型参数的保存和恢复的类。 在tensorflow中,模型参数通常是在训练过程中不断更新的,而为了保留训练过程中的模型参数,我们可以使用tf.train.Saver()类来保存这些参数。tf.train.Saver()类提供了保存和恢复模型的方法,可以将模型的参数保存到文件中,并在需要的时候恢复这些参数。 保存模型参数是通过调用tf.train.Saver()类的save()方法实现的。save()方法需要传入一个session和一个保存路径,表示将当前模型的参数保存到指定的路径下。保存的参数可以是全局变量、权重、偏置等等。 恢复模型参数是通过调用tf.train.Saver()类的restore()方法实现的。restore()方法需要传入一个session和一个保存路径,表示从指定的路径中恢复模型的参数。恢复参数时,tensorflow会自动判断模型的参数是否当前模型的参数匹配,如果匹配,则恢复参数;如果不匹配,则会抛出异常。 使用tf.train.Saver()类可以实现模型的断点续训。即在训练过程中,可以将当前的模型参数保存到文件中。如果训练过程中发生意外,可以在恢复训练时,加载之前保存的模型参数,从上一次中断的地方继续训练。 总之,tf.train.Saver()tensorflow中用于保存和恢复模型参数的重要工具,它提供了方便的接口,使得我们可以灵活地管理模型参数,实现模型的保存、恢复和断点续训。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值