# coding: utf-8
# 定义功能函数
# VGG16:输入层——卷积——卷积——池化——卷积——卷积——池化——卷积——卷积——卷积——池化——卷积——卷积——卷积——池化——卷积——卷积——卷积——池化——全连接——全连接——全连接——softmax
# In[1]:
import os
import tensorflow as tf
from time import time
import VGG16_model as model
import utils#定义了我们所用到的功能函数
from scipy.misc import imread,imresize
import numpy as np
# In[2]:
startTime=time()
batch_size=32
capacity=180#内存中存储的最大数据容量,根据自己的电脑配置而定
means=[123.68,116.779,103.939]#VGG训练时图像预处理所减均值(RGB三通道)
epoch=tf.Variable(0,name='epoch',trainable=False)#这个是不可训练的,相当于一个断点值,执行断点续训
sess=tf.Session()#声明会话
init=tf.global_variables_initializer()#调用变量
sess.run(init)#运行变量
# In[3]:
#设置检查点存储目录
ckpt_dir='./model/'
if not os.path.exists(ckpt_dir):#如果目录下不存在ckpt_dir
os.makedirs(ckpt_dir)#创建ckpt_dir文件
saver=tf.train.Saver(max_to_keep=1)#生成saver,用于保存和提取变量
#如果有检查点文件,读取最新检查点文件,恢复各种变量值
ckpt=tf.train.latest_checkpoint(ckpt_dir)
#创建summary_writer,用于写图文件
summary_writer=tf.summary.FileWriter(ckpt_dir,sess.graph)
#如果有检查点文件,恢复检查点文件,恢复各种变量值
ckpt=tf.train.latest_checkpoint(ckpt_dir)
#saver.restore(sess,'./model/')#恢复最后保存的模型
if ckpt !=None:
saver.restore(sess,ckpt)#加载所有的参数
#从这里开始就可以直接使用模型进行预测,或者接着继续训练了
else:
print('training from scratch')
#获取训练参数
start=sess.run(epoch)
print('traing starts from {} epoch'.format(start+1))
# In[ ]:
xs,ys=utils.get_file('data/train/')#获取图像列表和标签列表
image_batch,label_batch=utils.get_batch(xs,ys,224,224,batch_size,capacity)#通过读取列表来载入批量图片及标签
x=tf.placeholder(tf.float32,[None,224,224,3])
y=tf.placeholder(tf.int32,[None,2])
vgg=model.vgg16(x)#输出模型
fc8_finetuining=vgg.probs#即sofemax(fc8)微调(finetuining)sofemax(fc8)
loss_function=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=fc8_finetuining,labels=y))
optimizer=tf.train.GradientDescentOptimizer(0.001).minimize(loss_function)#GradientDescentOptimizer下降优化器
sess=tf.Session()#声明会话
init=tf.global_variables_initializer()#调用变量
sess.run(init)#运行变量
vgg.load_weights('vgg16_weights.npz',sess)#通过npz格式的文件获取VGG的相应权重参数,从而将权重注入即可实现复用
saver=tf.train.Saver()#生成saver,用于保存和提取变量
print('Model restoting......')
#saver.restore(sess,'./model/')#恢复最后保存的模型
#saver.restore(sess,'.model/epoch_00800.ckpt')恢复指定检查点的模型
#print('traing starts from {} epoch'.format(start+1))
coord=tf.train.Coordinator()#使用协调器Coordinator来管理线程
threads=tf.train.start_queue_runners(coord=coord,sess=sess)
epoch_start_time=time()
for i in range(start,1000):
images,labels=sess.run([image_batch,label_batch])
labels=utils.onehot(labels)#用one-hot对标签进行编码
sess.run(optimizer,feed_dict={x:images,y:labels})
loss=sess.run(loss_function,feed_dict={x:images,y:labels})
print('现在的损失为:%f'%loss)
epoch_end_time=time()
print('当前训练花费的时间:',(epoch_end_time-epoch_start_time))
epoch_start_time=epoch_end_time
#保存检查点
saver.save(sess,os.path.join('model/','epoch{:06d}.ckpt'.format(i)), global_step=i+1)
sess.run(epoch.assign(i+1))
print('===============Epoch %d is finished==============='%i)
#模型保存
#saver.save(sess,'./model/')
print('Optimization Finished!')
duration=time()-startTime
print('训练完成花费的时间:','{:.2f}'.format(duration))
coord.request_stop()#通知其它线程关闭
coord.join(threads)#join操作等待其他线程结束,其他所有线程关闭之后,这一函数才能返回
猫狗识别
最新推荐文章于 2023-07-26 16:12:43 发布