下载tensorfow-master
https://github.com/tensorflow/tensorflow
解压到相应目录会发现tensorflow-master\tensorflow\examples\image_retraining并没有想要的retrain.py等文件
在readme.md发现代码已经被转移到了https://github.com/tensorflow/hub/tree/master/examples/image_retraining
下载数据集
先去http://www.robots.ox.ac.uk/~vgg/data/网站下载一些数据集并放入自己的retrain文件夹里
写bat文件
- 第一行写上retrain.py的路径
- ^是没有换行的意思
- bottleneck_dir为瓶颈路径(谷歌把这个inception-v3网络比作一个花瓶,大概是pool_3的位置),把每张图片对应的值保存到bottleneck里
- how_many_training_steps为训练的周期
- model_dir为模型文件路径
- output_graph为输出的训练好的模型到当前目录
- output_labels为输出标签
- image_dir为传进来的需要分类的图片路径(最好不要中文和大写字母)
(多次运行要记得删所在盘符的tmp文件)
retrain.bat文件内容如下:
python C:/Users/Administrator/TensorFlow/tensorflow-master/tensorflow/examples/image_retraining/retrain.py ^
--bottleneck_dir bottleneck ^
--how_many_training_steps 200 ^
--model_dir C:/Users/Administrator/TensorFlow/inception_model/ ^
--output_graph output_graph.pb ^
--output_labels output_labels.txt ^
--image_dir data/train/
pause
点击运行bat文件发现会报错,于是上网查找资料,发现需要打开Anaconda Prompt执行:
默认每十步会打印出训练的结果
运行完之后会发现retrain文件夹里的bottleneck文件夹里的新文件
每一个图片对应一个txt文件
output_labels.txt文件内容:
之后就可以调用output_graph.pb模型文件来对自己的图片做分类。
在网上下载了一些图片放入新建的image文件夹来测试一下模型。
代码如下:
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt
lines = tf.gfile.GFile('retrain/output_labels.txt').readlines()
uid_to_human = {}
#一行一行读取数据
for uid,line in enumerate(lines) :
#去掉换行符
line=line.strip('\n')
uid_to_human[uid] = line
def id_to_string(node_id):
if node_id not in uid_to_human:
return ''
return uid_to_human[node_id]
#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile('retrain/output_graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
#遍历目录
for root,dirs,files in os.walk('retrain/image/'):
for file in files:
#载入图片
image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()
predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
predictions = np.squeeze(predictions)#把结果转为1维数据
#打印图片路径及名称
image_path = os.path.join(root,file)
print(image_path)
#显示图片
img=Image.open(image_path)
plt.imshow(img)
plt.axis('off')
plt.show()
#排序
top_k = predictions.argsort()[::-1]
print(top_k)
for node_id in top_k:
#获取分类名称
human_string = id_to_string(node_id)
#获取该分类的置信度
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
print()
运行结果如下:
可见基本都能识别成功。