This post records several ways to read images in TensorFlow. The official documentation covers them more comprehensively.
1. Reading an image with gfile: the decode output is a Tensor, which becomes an ndarray after eval()
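```python
import matplotlib.pyplot as plt
import tensorflow as tf  # TensorFlow 1.x API (output below is from 1.3.0)

print(tf.__version__)

# Read the raw JPEG bytes eagerly in Python; this is plain bytes, not a Tensor
with tf.gfile.FastGFile('test/a.jpg', 'rb') as f:
    image_raw = f.read()

# Decode the bytes into a uint8 Tensor of shape (height, width, channels)
img = tf.image.decode_jpeg(image_raw)

with tf.Session() as sess:
    print(type(image_raw))
    print(type(img))
    img_value = img.eval()  # evaluate once and reuse the resulting ndarray
    print(type(img_value))
    print(img_value.shape)
    print(img_value.dtype)
    plt.figure(1)
    plt.imshow(img_value)
    plt.show()
```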
Output:
```
1.3.0
<class 'bytes'>
<class 'tensorflow.python.framework.ops.Tensor'>
<class 'numpy.ndarray'>
(666, 1000, 3)
uint8
```
(Image display omitted.)
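The printed shape is (height, width, channels). The JPEG frame header (the SOFn segment) already stores the height, the width and the number of color components, so you can check them without decoding any pixels. Below is a minimal pure-Python sketch that assumes a well-formed file. `jpeg_shape` is a hypothetical helper written for this note, not a TensorFlow API. It stops at the first frame header and does not validate the rest of the file.

```python
import struct

# SOFn (frame header) markers; 0xC4 (DHT), 0xC8 (JPG) and 0xCC (DAC) share
# the 0xC0-0xCF range but are not frame headers
SOF_MARKERS = set(range(0xC0, 0xD0)) - {0xC4, 0xC8, 0xCC}


def jpeg_shape(data):
    """Return (height, width, components) read from the JPEG frame header."""
    if data[:2] != b'\xff\xd8':
        raise ValueError('not a JPEG: missing SOI marker')
    i = 2
    while i + 1 < len(data):
        if data[i] != 0xFF:
            raise ValueError('expected a marker at offset %d' % i)
        marker = data[i + 1]
        if marker == 0xFF:                   # fill byte before a marker
            i += 1
            continue
        if marker == 0x01 or 0xD0 <= marker <= 0xD9:
            i += 2                           # standalone marker, no length field
            continue
        (length,) = struct.unpack('>H', data[i + 2:i + 4])
        if marker in SOF_MARKERS:
            # payload: precision (1 byte), height (2), width (2), components (1)
            _, height, width, comps = struct.unpack('>BHHB', data[i + 4:i + 10])
            return height, width, comps
        if marker == 0xDA:                   # start of scan: no frame header seen
            break
        i += 2 + length                      # skip marker bytes plus segment
    raise ValueError('no SOF segment found')


# Synthetic header: SOI + APP0 (JFIF) + SOF0 for a 666x1000, 3-component image
header = (b'\xff\xd8'
          + b'\xff\xe0' + struct.pack('>H', 16) + b'JFIF\x00' + bytes(9)
          + b'\xff\xc0' + struct.pack('>HBHHB', 17, 8, 666, 1000, 3) + bytes(9))
print(jpeg_shape(header))  # (666, 1000, 3)
```

For an ordinary color JPEG such as `test/a.jpg`, `jpeg_shape(image_raw)` should agree with `img.eval().shape`. `decode_jpeg` with its default `channels=0` keeps the channel count stored in the file.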
2. Feeding a queue with WholeFileReader: the decode output is a Tensor, which becomes an ndarray after eval()
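```python
import os

import matplotlib.pyplot as plt
import tensorflow as tf  # TensorFlow 1.x queue-based input pipeline


def file_name(file_dir):
    # Debug helper: print each directory, its subdirectories and its files
    for root, dirs, files in os.walk(file_dir):
        print(root)
        print(dirs)
        print(files)


def file_name2(file_dir):
    # Collect the paths of all .jpg files under file_dir (recursively)
    L = []
    for root, dirs, files in os.walk(file_dir):
        for file in files:
            if os.path.splitext(file)[1] == '.jpg':
                L.append(os.path.join(root, file))
    return L


path = file_name2('test')

# File-name queue: every file is produced once per epoch, for 2 epochs
file_queue = tf.train.string_input_producer(path, shuffle=True, num_epochs=2)
image_reader = tf.WholeFileReader()
key, image = image_reader.read(file_queue)  # key: file name, image: raw bytes
image = tf.image.decode_jpeg(image)

with tf.Session() as sess:
    # num_epochs is kept in a local variable, which must be initialized
    tf.local_variables_initializer().run()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # Each eval() dequeues one file: 2 epochs -> len(path) * 2 images
    for _ in path + path:
        plt.figure()
        plt.imshow(image.eval())
        plt.show()
    coord.request_stop()
    coord.join(threads)
```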
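`string_input_producer(path, shuffle=True, num_epochs=2)` hands out every file name once per epoch, shuffling each epoch. It closes the queue after the second pass. That is why the loop runs over `path + path`: there is one `image.eval()` per queued file, and one more call would raise `tf.errors.OutOfRangeError`. The pure-Python generator below only models that behavior for illustration. `epoch_file_queue` is a hypothetical name, and the file names are made up. TensorFlow's real queue also involves a capacity limit and background threads.

```python
import random


def epoch_file_queue(paths, num_epochs, shuffle=True, seed=None):
    """Yield file names epoch by epoch, as a rough model of the TF queue.

    Each epoch is one full (optionally shuffled) pass over paths; iteration
    stops after num_epochs passes, where TensorFlow would raise
    OutOfRangeError on the next dequeue.
    """
    rng = random.Random(seed)
    for _ in range(num_epochs):
        epoch = list(paths)
        if shuffle:
            rng.shuffle(epoch)
        for p in epoch:
            yield p


paths = ['test/a.jpg', 'test/b.jpg', 'test/c.jpg']  # hypothetical file names
order = list(epoch_file_queue(paths, num_epochs=2, seed=0))
print(len(order))  # 6, the same as len(paths + paths)
```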
3. Using read_file: the decode output is a Tensor, which becomes an ndarray after eval()
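```python
import matplotlib.pyplot as plt
import tensorflow as tf  # TensorFlow 1.x API (output below is from 1.3.0)

print(tf.__version__)

# tf.read_file is a graph op: the file is read only when the session runs it,
# so image_value is a string Tensor rather than Python bytes
image_value = tf.read_file('test/a.jpg')
# channels=3 forces a 3-channel (RGB) output
img = tf.image.decode_jpeg(image_value, channels=3)

with tf.Session() as sess:
    print(type(image_value))
    print(type(img))
    img_value = img.eval()  # evaluate once and reuse the resulting ndarray
    print(type(img_value))
    print(img_value.shape)
    print(img_value.dtype)
    plt.figure(1)
    plt.imshow(img_value)
    plt.show()
```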
Output (note that `image_value` is now a Tensor, not `bytes` as in part 1):
```
1.3.0
<class 'tensorflow.python.framework.ops.Tensor'>
<class 'tensorflow.python.framework.ops.Tensor'>
<class 'numpy.ndarray'>
(666, 1000, 3)
uint8
```
(Image display omitted.)
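Unlike part 1, this version passes `channels=3`. According to the `tf.image.decode_jpeg` documentation, `channels=0` (the default) keeps the number of channels stored in the JPEG, `1` produces grayscale and `3` produces RGB. So with `channels=3`, even a grayscale JPEG decodes to shape (height, width, 3). The tiny helper below makes the resulting shapes explicit. It is hypothetical, a model of the documented rule rather than a TensorFlow function.

```python
def decoded_jpeg_shape(height, width, jpeg_components, channels=0):
    """Shape that tf.image.decode_jpeg should return for a given `channels`.

    Hypothetical helper modelling the documented rule: 0 keeps the channel
    count stored in the JPEG, 1 forces grayscale, 3 forces RGB.
    """
    if channels == 0:
        return (height, width, jpeg_components)
    if channels in (1, 3):
        return (height, width, channels)
    raise ValueError('channels must be 0, 1 or 3')


print(decoded_jpeg_shape(666, 1000, 3))              # (666, 1000, 3)
# A hypothetical grayscale image of the same size:
print(decoded_jpeg_shape(666, 1000, 1))              # (666, 1000, 1)
print(decoded_jpeg_shape(666, 1000, 1, channels=3))  # (666, 1000, 3)
```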
4. TFRecords: to be covered later.
If the images are already sorted into one folder per class, you can use the code from these posts directly:
http://blog.youkuaiyun.com/u012759136/article/details/52232266
https://www.2cto.com/kf/201702/604326.html
If the data needs some reorganizing first, see:
http://blog.youkuaiyun.com/hjxu2016/article/details/76165559
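The code in the linked posts is not reproduced here. As a rough idea of the class-per-folder approach, the sketch below turns folder names into integer labels. It assumes a layout of `root/<class name>/<image>.jpg` and uses a hypothetical helper, `list_images_by_class`. In TensorFlow 1.x, the resulting `paths` and `labels` lists could then be fed to a queue such as `tf.train.slice_input_producer([paths, labels])`.

```python
import os
import tempfile


def list_images_by_class(root, exts=('.jpg', '.jpeg')):
    """Return (paths, labels, class_names) for a root/<class>/<image> layout.

    Each immediate subdirectory of root is one class; labels are indices into
    the sorted class_names list. Extension matching is case-insensitive.
    """
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
    paths, labels = [], []
    for label, name in enumerate(class_names):
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            if os.path.splitext(fname)[1].lower() in exts:
                paths.append(os.path.join(class_dir, fname))
                labels.append(label)
    return paths, labels, class_names


# Usage on a throwaway tree: <tmp>/cat/{1.jpg,2.jpg}, <tmp>/dog/{a.JPG,notes.txt}
with tempfile.TemporaryDirectory() as demo_root:
    for cls, files in {'cat': ['1.jpg', '2.jpg'], 'dog': ['a.JPG', 'notes.txt']}.items():
        os.makedirs(os.path.join(demo_root, cls))
        for fname in files:
            open(os.path.join(demo_root, cls, fname), 'wb').close()
    paths, labels, class_names = list_images_by_class(demo_root)
    names = [os.path.basename(p) for p in paths]

print(class_names, labels, names)  # ['cat', 'dog'] [0, 0, 1] ['1.jpg', '2.jpg', 'a.JPG']
```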