即使queue_runner中shuffle=False 在多线程的循环的时候还是谁快谁先出
感觉是异步的。
所以只好fetch,将文件名字也抛出来,来解决wenj文件名字对不上的问题。
"""Translate an image to another image
An example of command-line usage is:
python export_graph.py --model pretrained/apple2orange.pb \
--input input_sample.jpg \
--output output_sample.jpg \
--image_size 128
"""
import tensorflow as tf
import os
from model import CycleGAN
import utils
from glob import glob
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('model', '', 'model path (.pb)')
tf.flags.DEFINE_string('input_dir', 'samples/input_n2v', 'input image path (.jpg)')
tf.flags.DEFINE_string('output_dir', 'samples/output_n2v', 'output image path (.jpg)')
tf.flags.DEFINE_integer('image_size', '128', 'image size, default: 128')
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
config = tf.ConfigProto()
config.gpu_options.allocator_type = 'BFC'
config.gpu_options.allow_growth = True
def get_all_org_files(file_dir):
L = []
for root, dirs, files in os.walk(file_dir):
for file in files:
if os.path.splitext(file)[1] == '.bmp':
L.append(os.path.join(root, file))
return L
# 其中os.path.splitext()函数将路径拆分为文件名+扩展名
def inference():
graph = tf.Graph()
# old_img_file_path = get_all_org_files(FLAGS.input_dir)
# !!!改
model_name = FLAGS.model.split('/')[-1]
if model_name == 'n2v.pb':
input_shape = [FLAGS.image_size, FLAGS.image_size, 1]
input_channel = 1
input_name = 'input_image_x'
else:
input_shape = [FLAGS.image_size, FLAGS.image_size, 3]
input_channel = 3
input_name = 'input_image_y'
with graph.as_default():
# 设置decoder
# !!!改 linux
paths = glob("{}/*.{}".format(FLAGS.input_dir, 'bmp')) # 使用通配符进行文件查找,找到第一个文件名字为jpg就把tf的decode设置为jpg
tf_decode = tf.image.decode_bmp
# 到这里结束,path是已经定义好的,相当于所有的图片名称已经确定了
# shuffle 默认为true 最好设置为false 否则输入queue的顺序不一定一致
filename_queue = tf.train.string_input_producer(list(paths), shuffle=False)
# 读取queue数据的reader
reader = tf.WholeFileReader()
# 读取图片数据,这只是定义,并没有真正去读取,在sess中真正读取图片
filename, data = reader.read(filename_queue)
# 解码取出到image里面,decode,必须在session里面运行
image = tf_decode(data, channels=input_channel) # 使用channels=0是直接使用bmp的channel数据,应该也行,好像不行
# image是数据的来源
# reshape这个tensor,将维度改变
image.set_shape(input_shape)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
# 只能读取一张image的情况,这里果然不能放一个Tensor,这样的做法是错的!!!
# input = tf.placeholder(dtype=tf.string)
# with tf.gfile.FastGFile(input, 'rb') as f:
# image_data = f.read()
# input_image = tf.image.decode_jpeg(image_data, channels=1)
# input_image = tf.image.resize_images(input_image, size=(FLAGS.image_size, FLAGS.image_size))
# input_image = utils.convert2float(input_image)
# input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 1])
with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_file.read())
# !!!改 ???前面export_graph name改了,这边是不是也要改 是的!
[output_image] = tf.import_graph_def(graph_def,
input_map={input_name: image},
return_elements=['output_image:0'],
name='output')
with tf.Session(graph=graph,config=config) as sess:
coord = tf.train.Coordinator() # 协同启动的线程
threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 启动线程运行队列
for i in range(len(paths)):
generated, it_path = sess.run((output_image, filename))
# import pdb;pdb.set_trace()
it_name = it_path.decode('utf8').split('/')[-1]
with open('{}/{}'.format(FLAGS.output_dir,it_name), 'wb') as f:
f.write(generated)
# for i in range(len(paths)):
# generated = output_image.eval()
# with open('{}/{:05d}.jpg'.format(FLAGS.output_dir,i+1), 'wb') as f:
# f.write(generated)
coord.request_stop() # 停止所有的线程
coord.join(threads)
def main(unused_argv):
inference()
if __name__ == '__main__':
tf.app.run()