今天尝试用 Keras 实现 SimGAN,在读入数据时遇到了几个坑,记录一下。
# coding: utf-8
import os
import sys
import keras
from keras import applications
from keras import layers
from keras import models
from keras import optimizers
from keras.preprocessing import image
import h5py
import numpy as np
import tensorflow as tf
from hs_util import *
from gen_captcha import *
from const import *
# Select which GPU this process may see (device ids: 0, 1, 2, ...).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# allow_growth: let GPU memory usage grow as the process needs it instead of
# claiming everything up front.
# per_process_gpu_memory_fraction: memory-fraction setting passed as 0.2.
config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(
        allow_growth=True,
        per_process_gpu_memory_fraction=0.2,
    )
)
sess = tf.Session(config=config)
# Image size, taken from the project-level constants.
img_height, img_width, img_channels = IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS

# Image data generator.
# No preprocessing_function is passed, so (per the original author's note)
# the yielded images keep pixel values in 0-255, stored as floats.
# The alternative the author tried was:
#     image.ImageDataGenerator(
#         preprocessing_function=applications.xception.preprocess_input,
#         data_format='channels_last')
# With that preprocessing function the author reports values between -1 and 1,
# and that the exact result depends on the backend framework -- read the Keras
# source before relying on it.
datagen = image.ImageDataGenerator(data_format='channels_last')

# Keyword arguments forwarded to flow_from_directory().
# class_mode=None is passed because no labels are requested here.
flow_from_directory_params = dict(
    target_size=(img_height, img_width),
    color_mode='rgb' if img_channels != 1 else 'grayscale',
    class_mode=None,
    batch_size=batch_size,
)
# Keyword arguments kept for flow(); not used in the visible part of the file.
flow_params = dict(batch_size=batch_size)

# Local path holding the real images. Per the author's note, this directory
# contains sub-folders and the image files live inside those sub-folders.
real_generator = datagen.flow_from_directory(
    './data/real',
    **flow_from_directory_params
)

# Fetch one batch of real images and display it.
# NOTE(review): this call runs before the `def plot_batch` later in this file
# has executed, so the name must already be provided by one of the star
# imports above -- confirm which implementation is intended.
real_image_batch = get_image_batch_real(real_generator)
plot_batch(real_image_batch)
def plot_batch(image_batch1, image_batch2=None, figure_path=None):
    """Display the first ``window_row * 4`` images of a batch in a grid.

    The grid has ``window_row`` rows and 4 columns; cell (row, col) shows
    ``image_batch1[row * 4 + col]``. ``plt`` and ``window_row`` are read from
    module scope (they are not defined in the visible part of this file).

    Args:
        image_batch1: Indexable batch of images supporting ``len()``. Each
            image is divided by 255 before display; per the original author's
            note, ImageDataGenerator output is float data in 0-255 and needs
            this scaling to render correctly.
        image_batch2: Currently unused. Optional so that single-argument
            calls work.
        figure_path: Currently unused. Optional so that single-argument
            calls work.
    """
    n_cols = 4
    # squeeze=False keeps `axs` two-dimensional even when window_row == 1,
    # so the [row, col] indexing below is always valid.
    _, axs = plt.subplots(window_row, n_cols, squeeze=False)
    n_images = len(image_batch1)
    for row in range(window_row):
        for col in range(n_cols):
            ax = axs[row, col]
            idx = row * n_cols + col
            if idx >= n_images:
                # The batch is shorter than the grid: blank the spare cell
                # instead of raising IndexError.
                ax.axis('off')
                continue
            # See the matplotlib docs for how cmap applies to 2-D image data.
            ax.imshow(image_batch1[idx] / 255, cmap='gray')
    plt.show()