想把 CIFAR-10 数据集（二进制版本）中的图片取出来看一看，脚本如下：
import os
import cv2
import numpy as np
from six.moves import xrange
import tensorflow as tf
# Command-line flags (TF1 `tf.app.flags`); they are parsed by tf.app.run().
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('data_dir', '/home/jdnie/cifar10_data',
                           """Path to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_string('out_dir', '/home/jdnie/cifar10_data/output',
                           """Path to output the CIFAR-10 data directory.""")
# data_batch_1..5.bin hold 5 x 10000 = 50000 training records. The input
# queue built in read_cifar10 leaves num_epochs unset, so it cycles through
# the files forever; any value above 50000 only re-writes duplicates of the
# same images under new file names. 50000 is exactly one full pass.
tf.app.flags.DEFINE_integer('max_images', 50000,
                            """Number of images to save.""")
def read_cifar10():
    """Build graph ops that yield one CIFAR-10 training record per evaluation.

    Reads ``data_batch_1.bin`` .. ``data_batch_5.bin`` from
    ``<FLAGS.data_dir>/cifar-10-batches-bin`` through a TF1 filename queue.
    Each fixed-length record is one label byte followed by 32*32*3 image
    bytes stored channel-major (all of channel 0, then 1, then 2).

    Returns:
        A ``(uint8image, label)`` tuple of tensors. ``uint8image`` is a
        ``[32, 32, 3]`` uint8 image in height-width-depth order. ``label``
        is a ``[1]`` int32 tensor. Each ``sess.run`` of them consumes the
        next record from the queue.

    Raises:
        ValueError: if ``FLAGS.data_dir`` is empty or a batch file is missing.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in xrange(1, 6)]
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    # num_epochs is left at None, so the queue cycles through the files
    # indefinitely; by default each pass visits the files in shuffled order.
    filename_queue = tf.train.string_input_producer(filenames)

    label_bytes = 1
    height = 32
    width = 32
    depth = 3
    image_bytes = height * width * depth
    record_bytes = label_bytes + image_bytes

    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    _, value = reader.read(filename_queue)  # the record key is not needed
    # Use a separate name so the int record length above is not shadowed.
    record = tf.decode_raw(value, tf.uint8)

    label = tf.cast(tf.slice(record, [0], [label_bytes]), tf.int32)
    depth_major = tf.reshape(tf.slice(record, [label_bytes], [image_bytes]),
                             [depth, height, width])
    # [depth, height, width] -> [height, width, depth]
    uint8image = tf.transpose(depth_major, [1, 2, 0])
    return uint8image, label
def train():
    """Decode CIFAR-10 training records and save each one as a JPEG file.

    Writes up to ``FLAGS.max_images`` images into ``FLAGS.out_dir``, which is
    created if missing. Files are named ``data_batch_<step>_<label>.jpg``.

    Raises:
        IOError: if OpenCV fails to write an image.
    """
    with tf.Graph().as_default():
        # read_cifar10 returns an (image, label) pair of tensors.
        image, label = read_cifar10()

        dest_directory = FLAGS.out_dir
        if not os.path.exists(dest_directory):
            os.makedirs(dest_directory)

        with tf.Session() as sess:
            # The coordinator lets us stop and join the queue-runner threads
            # before the session closes, instead of leaving them blocked on a
            # closed session when the script exits.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for step in xrange(FLAGS.max_images):
                    if coord.should_stop():
                        break
                    image_value, label_value = sess.run([image, label])
                    # label_value has shape [1]; .item() gives a plain int
                    # (formatting a 1-element array with %d is deprecated).
                    output_filename = os.path.join(
                        dest_directory,
                        'data_batch_%d_%d.jpg' % (step, label_value.item()))
                    # CIFAR-10 pixels are RGB, but cv2.imwrite expects BGR.
                    bgr_image = cv2.cvtColor(image_value, cv2.COLOR_RGB2BGR)
                    if not cv2.imwrite(output_filename, bgr_image):
                        raise IOError('Failed to write image: ' +
                                      output_filename)
            except tf.errors.OutOfRangeError:
                # Raised if the input queue is exhausted. read_cifar10 as
                # written cycles indefinitely, so this is only a safeguard.
                pass
            finally:
                coord.request_stop()
                coord.join(threads)
def main(argv=None):
    """Entry point called by ``tf.app.run()``; ``argv`` is accepted but unused."""
    del argv  # unused: train() takes its settings from FLAGS
    train()
if __name__ == '__main__':
    # tf.app.run parses the command-line flags, then invokes main(argv).
    tf.app.run(main=main)