Classification Task Notes
This is a classification task.
net.py is the network definition.
dataset.py is the dataset interface definition.
train.py is the training file; run_train.sh is its launch script.
inference.py is the test file; run_inference.sh is its launch script.
0. Notes
1. sys.argv[]
sys.argv is the bridge for passing arguments into a program from the outside (the command line).
# test.py
import sys
a = sys.argv[0]
print(a)
>>> python test.py what
test.py
# with a = sys.argv[1] instead:
>>> python test.py what
what
2. tf.convert_to_tensor()
Converts other data into tensors: for example, it turns a NumPy array or a Python list into a tensor. The result has type
<class 'tensorflow.python.framework.ops.Tensor'>
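A minimal sketch of the conversion (TF 1.x):
import tensorflow as tf
import numpy as np
t1 = tf.convert_to_tensor([1, 2, 3])         # Python list -> tensor
t2 = tf.convert_to_tensor(np.zeros((2, 2)))  # NumPy array -> tensor
print(type(t1))  # <class 'tensorflow.python.framework.ops.Tensor'>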
3. tf.argmax()
tf.argmax(input, axis) returns the index of the maximum value in each row (axis=1) or each column (axis=0).
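For example:
t = tf.constant([[1, 3, 2], [5, 4, 6]])
tf.argmax(t, 1)  # -> [1, 2]: position of each row's maximum
tf.argmax(t, 0)  # -> [1, 1, 1]: position of each column's maximum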
4. tf.equal()
tf.equal(x, y, name=None) compares x and y element-wise:
True where they are equal, False where they are not.
5. tf.cast()
Converts a tensor to another data type, e.g., booleans to floats. Points 3-5 combine into the accuracy computation that train.py uses later; see the sketch below.
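A minimal sketch of that accuracy pattern:
import tensorflow as tf
labels = tf.constant([[0., 1.], [1., 0.]])     # one-hot ground truth
probs = tf.constant([[0.2, 0.8], [0.4, 0.6]])  # softmax outputs
correct = tf.equal(tf.argmax(probs, 1), tf.argmax(labels, 1))  # [True, False]
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
with tf.Session() as sess:
    print(sess.run(accuracy))  # 0.5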
1. dataset.py: dataset interface definition
'''ImageData class
Inputs
@ txtfile      list file, one "imagename label" pair per line
@ batch_size   batch size, 64
@ num_classes  number of classes, 2
@ image_size   image size (48*48)
Outputs
@ img
@ label
'''
__init__
'''Initialization
Applies some preprocessing to the images.
image_size
batch_size
txt_file
num_classes
buffer_size  # used by shuffle
'''
buffer_size = batch_size * buffer_scale # 64*100
read_txt_file
img_paths  # image paths
labels     # image labels
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.framework.ops import convert_to_tensor
import numpy as np
class ImageData:
    def read_txt_file(self):
        self.img_paths = []  # image paths
        self.labels = []     # image labels
        for line in open(self.txt_file, 'r'):
            items = line.split(' ')
            self.img_paths.append(items[0])
            self.labels.append(int(items[1]))
    def __init__(self, txt_file, batch_size, num_classes,
                 image_size, buffer_scale=100):
        self.image_size = image_size
        self.batch_size = batch_size
        self.txt_file = txt_file  # txt list file, stored as: imagename label
        self.num_classes = num_classes
        buffer_size = batch_size * buffer_scale
        # read the image list
        self.read_txt_file()
        self.dataset_size = len(self.labels)  # number of training samples
        print("num of train samples=", self.dataset_size)
        # convert to tensors
        self.img_paths = convert_to_tensor(self.img_paths, dtype=dtypes.string)
        self.labels = convert_to_tensor(self.labels, dtype=dtypes.int32)
        # create the dataset
        data = tf.data.Dataset.from_tensor_slices((self.img_paths, self.labels))
        print("data type=", type(data))
        data = data.map(self.parse_function)
        data = data.repeat(1000)
        data = data.shuffle(buffer_size=buffer_size)
        # batch the dataset
        self.data = data.batch(batch_size)
        print("self.data type=", type(self.data))
    def augment_dataset(self, image, size):
        # randomly adjust brightness within a range
        distorted_image = tf.image.random_brightness(image, max_delta=63)
        # randomly adjust contrast within a range
        distorted_image = tf.image.random_contrast(distorted_image,
                                                   lower=0.2, upper=1.8)
        # subtract off the mean and divide by the variance of the pixels
        float_image = tf.image.per_image_standardization(distorted_image)
        return float_image

    def parse_function(self, filename, label):
        label_ = tf.one_hot(label, self.num_classes)  # one-hot encode the label
        img = tf.read_file(filename)                  # read the image file
        img = tf.image.decode_jpeg(img, channels=3)   # decode JPEG
        img = tf.image.convert_image_dtype(img, dtype=tf.float32)  # cast to float32
        img = tf.random_crop(img, [self.image_size[0], self.image_size[1], 3])  # random crop
        img = tf.image.random_flip_left_right(img)    # random horizontal flip
        img = self.augment_dataset(img, self.image_size)
        return img, label_
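A minimal usage sketch, mirroring what train.py does below ("train.txt" is a hypothetical list file with one "imagename label" pair per line):
dataset = ImageData("train.txt", batch_size=64, num_classes=2, image_size=(48, 48))
iterator = dataset.data.make_one_shot_iterator()
batch_images, batch_labels = iterator.get_next()  # shapes (64,48,48,3) and (64,2)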
2. net.py
Layer shapes:
x          size = (?, 48, 48, 3)
relu_conv1 size = (?, 23, 23, 12)
relu_conv2 size = (?, 11, 11, 24)
relu_conv3 size = (?, 5, 5, 48)
dense      size = (?, 128)
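These shapes follow from 3x3 kernels with stride 2 and the default VALID padding: out = floor((in - 3) / 2) + 1, so 48 -> 23 -> 11 -> 5.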
'''
Inputs
@ x           input data
@ istraining  whether the network is in training mode (controls batch norm)
'''
import tensorflow as tf
def simpleconv3(x, istraining):
    with tf.name_scope("simpleconv3"):
        with tf.variable_scope("conv3_net"):
            conv1 = tf.layers.conv2d(x, name="conv1", filters=12, kernel_size=[3, 3],
                strides=(2, 2), activation=tf.nn.relu,
                kernel_initializer=tf.contrib.layers.xavier_initializer(),
                bias_initializer=tf.contrib.layers.xavier_initializer())
            bn1 = tf.layers.batch_normalization(conv1, training=istraining, name='bn1')
            conv2 = tf.layers.conv2d(bn1, name="conv2", filters=24, kernel_size=[3, 3],
                strides=(2, 2), activation=tf.nn.relu,
                kernel_initializer=tf.contrib.layers.xavier_initializer(),
                bias_initializer=tf.contrib.layers.xavier_initializer())
            bn2 = tf.layers.batch_normalization(conv2, training=istraining, name='bn2')
            conv3 = tf.layers.conv2d(bn2, name="conv3", filters=48, kernel_size=[3, 3],
                strides=(2, 2), activation=tf.nn.relu,
                kernel_initializer=tf.contrib.layers.xavier_initializer(),
                bias_initializer=tf.contrib.layers.xavier_initializer())
            bn3 = tf.layers.batch_normalization(conv3, training=istraining, name='bn3')
            conv3_flat = tf.reshape(bn3, [-1, 5 * 5 * 48])
            dense = tf.layers.dense(inputs=conv3_flat, units=128, activation=tf.nn.relu,
                name="dense", kernel_initializer=tf.contrib.layers.xavier_initializer())
            # the logits layer must stay linear: the relu in the original code
            # would zero out all negative logits before the softmax cross-entropy
            logits = tf.layers.dense(inputs=dense, units=2, activation=None,
                name="logits", kernel_initializer=tf.contrib.layers.xavier_initializer())
    return logits
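A quick shape check (a sketch; the placeholder here is only for illustration):
x = tf.placeholder(tf.float32, [None, 48, 48, 3])
logits = simpleconv3(x, istraining=False)
print(logits.shape)  # (?, 2)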
3. train.py
import tensorflow as tf  # explicit import (was previously pulled in via "from dataset import *")
from dataset import *
from net import simpleconv3
import sys
import os
import cv2

txtfile = sys.argv[1]
batch_size = 64
num_classes = 2
image_size = (48, 48)
learning_rate = 0.0001
debug = False
if __name__ == "__main__":
    '''Load the data'''
    dataset = ImageData(txtfile, batch_size, num_classes, image_size)
    iterator = dataset.data.make_one_shot_iterator()
    dataset_size = dataset.dataset_size  # number of training samples
    batch_images, batch_labels = iterator.get_next()
    '''Build the training graph'''
    Ylogits = simpleconv3(batch_images, True)
    print("Ylogits size=", Ylogits.shape)
    Y = tf.nn.softmax(Ylogits)
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=Ylogits, labels=batch_labels)
    cross_entropy = tf.reduce_mean(cross_entropy)  # loss
    correct_prediction = tf.equal(tf.argmax(Y, 1), tf.argmax(batch_labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # accuracy
    # the batch-norm moving averages live in UPDATE_OPS and must run with each step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
    '''Save the model and write summaries'''
    saver = tf.train.Saver()
    in_steps = 100
    checkpoint_dir = 'checkpoints/'
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    log_dir = 'logs/'
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    summary = tf.summary.FileWriter(logdir=log_dir)
    loss_summary = tf.summary.scalar("loss", cross_entropy)
    acc_summary = tf.summary.scalar("acc", accuracy)
    image_summary = tf.summary.image("image", batch_images)
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        steps = 10000
        for i in range(steps):
            _, cross_entropy_, accuracy_, batch_images_, batch_labels_, loss_summary_, acc_summary_, image_summary_ = sess.run(
                [train_step, cross_entropy, accuracy, batch_images, batch_labels,
                 loss_summary, acc_summary, image_summary])
            if i % in_steps == 0:
                print(i, "iterations,loss=", cross_entropy_, "acc=", accuracy_)
                saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i)
                summary.add_summary(loss_summary_, i)
                summary.add_summary(acc_summary_, i)
                summary.add_summary(image_summary_, i)
                # print("predict=", Ylogits, " labels=", batch_labels)
4. inference.py: test interface
python inference.py ./checkpoints/model.ckpt-9900 ./val_shuffle.txt
'''
@ count               total number of images
@ poscount / negcount number of samples predicted positive / negative
@ posacc / negacc     correct predictions within each of those groups
'''
import tensorflow as tf
from net import simpleconv3
import sys
import numpy as np
import cv2
import os
testsize = 48
x = tf.placeholder(tf.float32, [1,testsize,testsize,3])
y = simpleconv3(x,False)
y = tf.nn.softmax(y)
lines = open(sys.argv[2]).readlines()  # load the test list from outside the program
count = 0
acc = 0
posacc = 0
negacc = 0
poscount = 0
negcount = 0
# Build the per-image preprocessing graph once, outside the loop;
# creating these ops inside the loop would keep growing the graph.
imagename_ph = tf.placeholder(tf.string)
img = tf.read_file(imagename_ph)             # read the image file
img = tf.image.decode_jpeg(img, channels=3)  # decode JPEG
img = tf.image.convert_image_dtype(img, dtype=tf.float32)  # cast to float32
img = tf.image.resize_images(img, (testsize, testsize),
                             method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
img = tf.image.per_image_standardization(img)  # standardize
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    saver = tf.train.Saver()
    saver.restore(sess, sys.argv[1])  # load the trained model
    # test one by one; you can change it into batch inputs
    for line in lines:
        imagename, label = line.strip().split(' ')
        imgnumpy = sess.run(img, feed_dict={imagename_ph: imagename})  # (48,48,3) ndarray
        imgs = np.zeros([1, testsize, testsize, 3], dtype=np.float32)  # (1,48,48,3)
        imgs[0:1, ] = imgnumpy
        result = sess.run(y, feed_dict={x: imgs})
        result = np.squeeze(result)  # drop the size-1 batch dimension
        '''Decide the predicted class'''
        if result[0] > result[1]:
            predict = 0
        else:
            predict = 1
        count = count + 1
        if predict == 0:
            negcount = negcount + 1
            if str(label) == str(predict):
                negacc = negacc + 1
                acc = acc + 1
        else:
            poscount = poscount + 1
            if str(label) == str(predict):
                posacc = posacc + 1
                acc = acc + 1
        print(result)

print("acc = ", float(acc) / float(count))
print("poscount=", poscount)
print("posacc = ", float(posacc) / float(poscount))
print("negcount=", negcount)
print("negacc = ", float(negacc) / float(negcount))