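In the previous article we took a brief look at part of the StyleGAN2 Encoder source code, namely projector.py and project_images.py. For details, see:
Using StyleGAN2 with Ease (5): A First Look at the StyleGAN2 Encoder Source Code with Chinese Comments, projector.py and project_images.py
Two functions in that code handle loading the training data, and in this article we take some time to look at them.
Both are called from project_images.py: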
(1) dataset_tool.create_from_images(tfrecord_dir, image_dir, shuffle=0)
(2) dataset_obj = dataset.load_dataset(
        data_dir=data_dir, tfrecord_dir='tfrecords',
        max_label_size=0, repeat=False, shuffle_mb=0
    )
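Each of them is covered below.
(I) The first function, dataset_tool.create_from_images(), is defined in ./dataset_tool.py. It is a module-level function (note that it takes no self argument) that does its work through the TFRecordExporter class defined in the same file.
create_from_images() serializes the shape and the image data at each lod (level of detail) and writes them to files. Its source code (with comments) is as follows: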
# Imports this function relies on; error() and TFRecordExporter are defined elsewhere in dataset_tool.py
import glob
import os
import numpy as np
import PIL.Image

def create_from_images(tfrecord_dir, image_dir, shuffle):
    print('Loading images from "%s"' % image_dir)
    image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
    if len(image_filenames) == 0:
        error('No input images found')
    # Resolution and channel count are read from the first image only
    img = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1  # img.ndim is the number of array dimensions: 2 for grayscale, 3 for color
    # e.g. an image whose shape is (1024, 1024, 3)
    if img.shape[1] != resolution:
        error('Input images must have the same width and height')
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        error('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        error('Input images must be stored as RGB or grayscale')
    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        # Shuffled or sequential order
        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            if channels == 1:
                img = img[np.newaxis, :, :]  # HW => CHW (height, width => channels, height, width)
            else:
                img = img.transpose([2, 0, 1])  # HWC => CHW (height, width, channels => channels, height, width)
            # Serialize the shape and image data at every lod and write them to file
            tfr.add_image(img)
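The input checks and the array-layout conversion in create_from_images() can be tried out without TensorFlow or an image folder. The following is a minimal NumPy-only sketch, not code from the repository: check_image(), hwc_to_chw() and lod_file_names() are hypothetical helpers, ValueError stands in for dataset_tool.py's error(), and the file-name rule is copied from the '-r%02d.tfrecords' % (resolution_log2 - lod) expression in add_image() further down.

```python
import numpy as np


def check_image(img):
    """Mirror the checks in create_from_images(); return (resolution, channels).

    Hypothetical helper: raises ValueError where dataset_tool.py calls error().
    """
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1  # a 2-D array is grayscale
    if img.shape[1] != resolution:
        raise ValueError('Input images must have the same width and height')
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        raise ValueError('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        raise ValueError('Input images must be stored as RGB or grayscale')
    return resolution, channels


def hwc_to_chw(img):
    """Convert HW (grayscale) or HWC (color) to CHW."""
    if img.ndim == 2:
        return img[np.newaxis, :, :]  # HW => 1HW
    return img.transpose([2, 0, 1])   # HWC => CHW


def lod_file_names(resolution, prefix='tfrecords'):
    """List (lod, side length, file name) for each lod, using the '-r%02d' naming rule."""
    resolution_log2 = int(np.log2(resolution))
    return [(lod, resolution >> lod, prefix + '-r%02d.tfrecords' % (resolution_log2 - lod))
            for lod in range(resolution_log2 - 1)]


if __name__ == '__main__':
    rgb = np.zeros((1024, 1024, 3), dtype=np.uint8)
    print(check_image(rgb))       # (1024, 3)
    print(hwc_to_chw(rgb).shape)  # (3, 1024, 1024)
    for lod, size, name in lod_file_names(1024):
        print(lod, size, name)    # from "0 1024 tfrecords-r10.tfrecords" to "8 4 tfrecords-r02.tfrecords"
```

For a 1024x1024 RGB input this yields nine entries, from lod 0 (1024 pixels, tfrecords-r10.tfrecords) down to lod 8 (4 pixels, tfrecords-r02.tfrecords). One deliberate difference from the original: hwc_to_chw() branches on img.ndim rather than on the channel count, so an (H, W, 1) array also ends up as (1, H, W).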
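The core work of create_from_images() is done by add_image(), a method of the TFRecordExporter class. For 1024x1024 source images, it loops over lod from 0 to 8 and produces shape and image data at a different size for each lod (1024x1024 at lod 0, halving down to 4x4 at lod 8), serializing each size into one of 9 temporary files, tfrecords-r10.tfrecords through tfrecords-r02.tfrecords. Its source code (with comments) is as follows: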
def add_image(self, img):
    if self.print_progress and self.cur_images % self.progress_interval == 0:
        print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True)
    # First call only: record the image shape and set up one output file per lod
    if self.shape is None:
        self.shape = img.shape
        self.resolution_log2 = int(np.log2(self.shape[1]))
        assert self.shape[0] in [1, 3]
        assert self.shape[1] == self.shape[2]
        assert self.shape[1] == 2**self.resolution_log2
        tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
        # e.g. for a 1024x1024 image, lod runs from 0 to 8 and the tfr_file suffix from 10 to 2, as in tfrecords-r05.tfrecords
        for lod in range(self.resolution_log2 - 1):
            tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod)  # 10 - lod for 1024x1024 input
            # TFRecordWriter writes records to a TFRecor