Tensorflow取出cifar中的图片

本文介绍了一个Python脚本,用于从CIFAR-10数据集中读取并展示图片数据。该脚本使用TensorFlow进行数据读取,并通过OpenCV将读取到的数据保存为JPEG格式的图片。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

想取出cifar中的文件看一看,脚本如下:

import os
import cv2
import numpy as np

from six.moves import xrange
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('data_dir', '/home/jdnie/cifar10_data',
                           """Path to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_string('out_dir', '/home/jdnie/cifar10_data/output',
                           """Path to output the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_integer('max_images', 1000000,
                            """Number of images to save.""")

def read_cifar10():
	if not FLAGS.data_dir:
		raise ValueError('Please supply a data_dir')
	data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
	filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
				 for i in xrange(1, 6)]
	for f in filenames:
		if not tf.gfile.Exists(f):
			raise ValueError('Failed to find file: ' + f)
	
	filename_queue = tf.train.string_input_producer(filenames)	
	
	label_bytes = 1
	height = 32
	width = 32
	depth = 3
	image_bytes = height * width * depth
	record_bytes = label_bytes + image_bytes	
	reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
	key, value = reader.read(filename_queue)
	record_bytes = tf.decode_raw(value, tf.uint8)
	label = tf.cast(
		tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
	depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]),
							 [depth, height, width])
	uint8image = tf.transpose(depth_major, [1, 2, 0])
	print(uint8image)
	
	return uint8image, label
	

def train():
	with tf.Graph().as_default():	
		image = read_cifar10()
		#init = tf.initialize_all_variables()
		sess = tf.InteractiveSession()
		#sess.run(init)
		tf.train.start_queue_runners(sess=sess)
		
		dest_directory = FLAGS.out_dir
		if not os.path.exists(dest_directory):
			os.makedirs(dest_directory)
		
		for step in xrange(FLAGS.max_images):
			image2, label = sess.run(image)
			#print(cv_image)
			output_filename = os.path.join(dest_directory, 'data_batch_%d_%d.jpg' % (step, label))
			cv2.imwrite(output_filename, image2)
			#cv2.imshow('cifar', image2)
			#cv2.waitKey(0)
		

def main(argv=None):
	train()

if __name__ == '__main__':
  tf.app.run()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值