tensorflow数据的生成与读取
明天就开始忙起来了,可能不会日更了,但还是希望能坚持把这个系列做完吧。
大纲:
- tensorflow的队列
- csv文件的创建和读取
- tensorflow文件的创建和读取
tensorflow的队列
队列的创建
首先,创建一个队列要选定数据的出入类型,例如FIFOQueue和RandomShffleQueue方式,前者为先入先出,后者为随即元素出列。
例:
Q = tf.FIFOQueue(3, "float")
该方法创建了一个先入先出的队列,队列数据个数为3,队列中元素类型为float
然后,可以使用enqueue_many() 函数对该队列进行数据填充。
sess = tf.Session()
init = Q.enqueue_many(([0.1, 0.2, 0.3],))
sess.run(init)
要注意的是,tensorflow中的任何操作都要在会话Session中完成,所以要等到会话的==run()==函数完成后,该队列才会真正被填充。
完整例子:
Q = tf.FIFOQueue(3, "float")
sess = tf.Session()
init = Q.enqueue_many(([0.1, 0.2, 0.3],))
init2 = Q.dequeue()
init3 = Q.enqueue(1.)
sess.run(init)
sess.run(init2)
sess.run(init3)
quelen = sess.run(Q.size())
for i in range(quelen):
print(sess.run(Q.dequeue()))
tensorflow中对队列的操作和python的队列基本相似
运行结果:
0.2
0.3
1.0
CSV文件的创建和读取
CSV(Comma-Separated Values)文件是最常用的文件存储方式,以纯文本形式存储表格数据。
这里采用图片地址和标签来写入CSV文件
图片目录如下:
创建:
import os
path = 'jpg'
filename = os.listdir(path)
strText = ""
with open("train_list.csv", "w") as fid:
for a in range(len(filename)):
strText = path+os.sep+filename[a] + "," + filename[a].split("_")[1][:1] + "\n"
fid.write(strText)
fid.close()
结果train_list.csv:
jpg\image_1.jpg,1
jpg\image_2.jpg,2
jpg\image_3.jpg,3
读取
image_add_list = []
image_label_list = []
with open("train_list.csv") as fid:
for image in fid.readlines():
image_add_list.append(image.strip().split(",")[0])
image_label_list.append(image.strip().split(",")[1])
def get_image(image_path):
return tf.image.convert_image_dtype(tf.image.decode_jpeg(tf.read_file(image_path),
channels=1), dtype=tf.float32)
for i in range(len(image_add_list)):
print(image_label_list[i], get_image(image_add_list[i]))
其中用到的函数:
tf.read_file:读取图片地址
tf.image.decode_jpeg把读取的图片解码成jpg格式,并把通道数设为1
tf.image.convert_image_dtype把图像转换为张量
运行结果:
1 Tensor("convert_image:0", shape=(?, ?, 1), dtype=float32)
2 Tensor("convert_image_1:0", shape=(?, ?, 1), dtype=float32)
3 Tensor("convert_image_2:0", shape=(?, ?, 1), dtype=float32)
可以看到,该函数将图片和scv文件转换成了标签和张量格式
tensorflow文件的创建和读取
除了典型的CSV文件读取外,tensorflow还有专门的文件存储格式:TFRecords文件。
TFRecords文件的创建
a_data = 0.834
b_data = [17]
c_data = np.array([[0, 1, 2], [3, 4, 5]])
c = c_data.astype(np.uint8)
c_raw = c.tostring() # 转化成字符串
example = tf.train.Example(
features=tf.train.Features(
feature={
'a': tf.train.Feature(
float_list=tf.train.FloatList(value=[a_data]) # 方括号表示输入为list
),
'b': tf.train.Feature(
int64_list=tf.train.Int64List(value=b_data) # b_data本身就是列表
),
'c': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[c_raw]) # c_raw被转化为byte格式
)
}
)
)
# 创建TFRecords文件
writer = tf.python_io.TFRecordWriter("trainArray.tfrecords")
for i in range(1):
# 创建样本example
# ...
serialized = example.SerializeToString()
writer.write(serialized)
writer.close()
上面代码是TFRecords写入文件的经典格式,即对样本的序列化之后进行写操作完成。可以看到TFRecords文件接收三中数据格式,分别为BytesList,Int64List和FloatList。
TFRecords文件的读取
TFRecords文件的读取则有些麻烦,要将TFRecords中的数据以输入的格式读取出来
代码如下:
# 读取TFRecords文件
filename_queue = tf.train.input_producer(["trainArray.tfrecords"], num_epochs=None)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'a': tf.FixedLenFeature([], tf.float32),
'b': tf.FixedLenFeature([], tf.int64),
'c': tf.FixedLenFeature([], tf.string)
}
)
a = features['a']
b = features['b']
c_raw = features['c']
c = tf.decode_raw(c_raw, tf.uint8)
c = tf.reshape(c, [2, 3])
print(a)
print(b)
print(c)
输出结果:
Tensor("ParseSingleExample/Squeeze_a:0", shape=(), dtype=float32)
Tensor("ParseSingleExample/Squeeze_b:0", shape=(), dtype=int64)
Tensor("Reshape:0", shape=(2, 3), dtype=uint8)
太晚了,这一部分剩余内容明天再更^ ^