一、tfrecords文件
tfrecords是一种二进制文件,可先将图片与标签制作成该格式的文件,使用tfrecords进行数据读取,会提高内存利用率,将不同输入文件统一起来。
二、mnist数据集
MNIST数据集是一个手写体数字集合,可到此处下载,数据包括四部分:训练图片集、训练标签集、测试图片集、测试标签集。该数据集的训练集中有55000张图片,验证集中5000有张图片,测试集中是10000张图片。
三、数据制作与读取
1、生成文件
文件生成的过程:
- 新建一个writer
- for循环遍历每张图和标签
- 把每张图和标签封装到example中
- 将example序列化
具体代码如下:
def generate_tfRecord():
isExists=os.path.exists(data_path) ##判断保存路径是否存在
if not isExists:
os.makedirs(data_path)
print('路径创建成功')
else:
print('路径已存在')
write_tfRecord(tfRecord_train, image_train_path, label_train_path) ##使用自定义函数将训练集生成名叫tfRecord_train的tfrecords文件
write_tfRecord(tfRecord_test, image_test_path, label_test_path) ##同理训练集
def write_tfRecord(tfRecordName, image_path, label_path):
writer=tf.python_io.TFRecordWriter(tfRecordName) ##创建一个writer
num_pic=0 ##计数器
f=open(label_path,'r') ##以读的形式打开标签文件
contents = f.readlines() ##读取整个文件内容
f.close()
for content in contents:
value=content.split() ##以空格分隔每行的内容,分割后组成列表value
img_path=image_path+value[0]
img=Image.open(img_path) ##打开图片
img_raw=img.tobytes() ##将图片转换为二进制数据
labels=[0]*10
labels[int(value[1])]=1 ##将labels所对应的标签为赋值为1
example=tf.train.Example(features=tf.train.Features(feature={ ##创建一个example,用一个features进行封装
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), ##在img_raw放入二进制图片
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))})) ##在labels放入图片对应的标签
writer.write(example.SerializeToString()) ##将example进行序列化
num_pic+=1
print('图片数为:',num_pic)
writer.close()
print('tfREcord文件写入成功')
其中标签文件的格式为:
2、读取文件
文件读取的过程:
- 新建一个reader
- 解序列化example读取图片和标签
- 将图片和标签转化为网络需要的格式
具体代码如下:
###实现了批获取训练集或测试集的图片和标签
def get_tfRecord(num, isTrain=True): ##参数num表示一次读取多少组
if isTrain:
tfRecord_path=tfRecord_train
else:
tfRecord_path=tfRecord_test
img,label=read_tfRecord(tfRecord_path)
img_batch, label_batch= tf.train.shuffle_batch([img, label],
batch_size=num,
capacity=1000,
min_after_dequeue=700,
num_threads=2)
def read_tfRecord(tfRecord_path):
filmname_queue=tf.train.string_input_producer([tfRecord_path])
reader=tf.TFRecordReader() ##新建一个reader
_,serialized_example = reader.read(filmname_queue) ##将读出的每一个样本保存到serialized_example中进行解序列化
features=tf.parse_single_example(serialized_example,features={ ##将图片和标签的键值要和制作数据集是的键值相同
'img_raw':tf.FixedLenFeature([],tf.string),
'label':tf.FixedLenFeature([10],tf.int64)})
img=tf.decode_raw(features['img_raw'],tf.uint8) ##将img_raw字符串转化为8位无符号整型
img.set_shape([784])
img=tf.cast(img,tf.float32)*(1./255) ##转化为浮点数形式
label=tf.cast(features['label'],tf.float32)
return img, label