实例7:将图片文件制作成Dataset数据集
在图片训练过程中,一个变形丰富的数据集会使模型的精度与泛化性成倍提升
1. 代码实现:读取样本文件的目录与标签
定义load_sample函数,用来将样本图片的目录名称与对应的标签读入内存。
import os
import tensorflow as tf
from PIL import Image
import numpy as np
from tqdm import tqdm
from sklearn.utils import shuffle
def load_sample(sample_dir, shuffle):
print("loading sample dataset...")
lfilenames = []
labelsnames = []
for (dirpath, dirnames, filenames) in os.walk(sample_dir):
for filename in filenames :
filename_path = os.sep.join([dirpath, filename])
lfilenames.append(filename_path)
labelsnames.append(dirpath.split('\\')[-1])
lab = list(sorted(set(labelsnames)))
labdict = dict(zip(lab, list(range(len(lab)))))
labels = [labdict[i] for i in labelsnames]
if shuffle == True:
return shuffle(np.asarray(lfilenames), np.asarray(labels)), np.asarray(lab)
else:
return np.asarray(lfilenames), np.asarray(labels), np.asarray(lab)
2. 代码实现:定义函数,实现函数转换操作
定义函数_distorted_image,用TensorFlow自带的API实现单一图片的变换处理
def _distorted_image(image, size, ch=1, shuffleflag=False, cropflag=False, brightnessflag=False, contrastflag=False):
distorted_image =tf.image.random_flip_left_right(image)
if cropflag == True: #随机裁剪
s = tf.random_uniform((1,2),int(size[0]*0.8),size[0],tf.int32)
distorted_image = tf.random_crop(distorted_image, [s[0][0],s[0][0],ch])
distorted_image = tf.image.random_flip_up_down(distorted_image)#上下随机翻转
if brightnessflag == True:#随机变化亮度
distorted_image = tf.image.random_brightness(distorted_image,max_delta=10)
if contrastflag == True: #随机变化对比度
distorted_image = tf.image.random_contrast(distorted_image,lower=0.2, upper=1.8)
if shuffleflag==True:
distorted_image = tf.random_shuffle(distorted_image)#沿着第0维乱序
return distorted_image
3. 代码实现:用自定义函数实现图片归一化
def _norm_image(image,size,ch=1,flattenflag = False): #定义函数,实现归一化,并且拍平
image_decoded = image/255.0
if flattenflag==True:
image_decoded = tf.reshape(image_decoded, [size[0]*size[1]*ch])
return image_decoded
本实例将图片的值域变成0~1之间的小数,实际开发中,也可以将图片的值域变成-1-1之间的小数
4. 代码实现:用第三方函数将图片旋转30度
定义函数random_rotated30实现图片旋转功能,用skimage库函数将图片旋转30度
在整个数据集的处理流程中,对图片的操作丢失基于张量进行的,所以第三方函数无法操作TensorFlow中张量,所以需要额外的封装
用tf.py_function函数可以将第三方 库函数成一个TensorFlow的中操作符(op)
from skimage import transform
def _random_rotated30(image, label): #定义函数实现图片随机旋转操作
def _rotated(image): #封装好的skimage模块,来进行图片旋转30度
shift_y, shift_x = np.array(image.shape.as_list()[:2],np.float32) / 2.
tf_rotate = transform.SimilarityTransform(rotation=np.deg2rad(30))
tf_shift = transform.SimilarityTransform(translation=[-shift_x, -shift_y])
tf_shift_inv,image.size = transform.SimilarityTransform(translation=[shift_x, shift_y]),image.shape#兼容transform函数
image_rotated = transform.warp(image, (tf_shift + (tf_rotate + tf_shift_inv)).inverse)
return image_rotated
def _rotatedwrap():
image_rotated = tf.py_function( _rotated,[image],[tf.float64]) #调用第三方函数
return tf.cast(image_rotated,tf.float32)[0]
a = tf.random_uniform([1],0,2,tf.int32)#实现随机功能
image_decoded = tf.cond(tf.equal(tf.constant(0),a[0]),lambda: image,_rotatedwrap)
return image_decoded, label
使用TensorFlow中的tf.cond方法,用来根据随机条件判断是否需要对本次图片进行旋转
5. 代码实现:定义函数,生成Dataset对象
咋dataset函数转给你,用内置函数_parseone将所有文件名转化为具体的图片内容,并返回Dataset队形
def dataset(directory,size,batchsize,random_rotated=False):#定义函数,创建数据集
(filenames,labels),_ =load_sample(directory,shuffleflag=False) #载入文件名称与标签
def _parseone(filename, label): #解析一个图片文件
image_string = tf.read_file(filename) #读取整个文件
image_decoded = tf.image.decode_image(image_string)
image_decoded.set_shape([None, None, None]) # 必须有这句,不然下面会转化失败
image_decoded = _distorted_image(image_decoded,size)#对图片做扭曲变化
image_decoded = tf.image.resize(image_decoded, size) #变化尺寸
image_decoded = _norm_image(image_decoded,size)#归一化
image_decoded = tf.cast(image_decoded,dtype=tf.float32)
label = tf.cast( tf.reshape(label, []) ,dtype=tf.int32 )#将label 转为张量
return image_decoded, label
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))#生成Dataset对象
dataset = dataset.map(_parseone) #有图片内容的数据集
if random_rotated == True:
dataset = dataset.map(_random_rotated30)
dataset = dataset.batch(batchsize) #批次划分数据集
return dataset
5. 代码实现:建立会话,输出数据
def showresult(subplot,title,thisimg): #显示单个图片
p =plt.subplot(subplot)
p.axis('off')
p.imshow(thisimg)
p.set_title(title)
def showimg(index,label,img,ntop): #显示
plt.figure(figsize=(20,10)) #定义显示图片的宽、高
plt.axis('off')
ntop = min(ntop,9)
print(index)
for i in range (ntop):
showresult(100+10*ntop+1+i,label[i],img[i])
plt.show()
def getone(dataset):
iterator = dataset.make_one_shot_iterator() #生成一个迭代器
one_element = iterator.get_next() #从iterator里取出一个元素
return one_element
sample_dir="man_woman"
size = [96,96]
batchsize = 10
tdataset = dataset(sample_dir,size,batchsize)
tdataset2 = dataset(sample_dir,size,batchsize,True)
print(tdataset.output_types) #打印数据集的输出信息
print(tdataset.output_shapes)
one_element1 = getone(tdataset) #从tdataset里取出一个元素
one_element2 = getone(tdataset2) #从tdataset2里取出一个元素
with tf.Session() as sess: # 建立会话(session)
sess.run(tf.global_variables_initializer()) #初始化
try:
for step in np.arange(1):
value = sess.run(one_element1)
value2 = sess.run(one_element2)
showimg(step,value[1],np.asarray( value[0]*255,np.uint8),10) #显示图片
#showimg(step,value2[1],np.asarray( value2[0]*255,np.uint8),10) #显示图片
except tf.errors.OutOfRangeError: #捕获异常
print("Done!!!")