轻轻松松使用StyleGAN2(六):StyleGAN2 Encoder是怎样加载训练数据的?源代码+中文注释,dataset_tool.py和dataset.py

本文深入解析StyleGAN2 Encoder的两个关键数据加载函数:dataset_tool.create_from_images()和dataset.load_dataset()。create_from_images()在TFRecordExporter类中,用于按不同分辨率将图像序列化存储;load_dataset()在TFRecordDataset类中,负责从tfrecords文件加载并预处理数据。文章包含源代码及中文注释。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

上一篇文章里,我们简单介绍了StyleGAN2 Encoder的一部分源代码,即:projector.py和project_images.py,内容请参考:

轻轻松松使用StyleGAN2(五):StyleGAN2 Encoder源代码初探+中文注释,projector.py和project_images.py

其中有两个函数,涉及到加载训练数据的功能,在这篇文章里我们花一点时间来看一下。

这两个函数都在project_images.py里,分别是:

(1)dataset_tool.create_from_images(tfrecord_dir, image_dir, shuffle=0)

(2)dataset_obj = dataset.load_dataset(
        data_dir=data_dir, tfrecord_dir='tfrecords',
        max_label_size=0, repeat=False, shuffle_mb=0
    )

下面分别对这两个函数进行介绍:

(一)第一个函数dataset_tool.create_from_images(),在./dataset_tool.py里定义,是TFRecordExporter类中的函数,其中:

create_from_images()的功能是按不同的lod(levels of detail),将对应的shape和图像信息序列化存入文件,其源代码如下(含中文注释):

def create_from_images(tfrecord_dir, image_dir, shuffle):
    print('Loading images from "%s"' % image_dir)
    image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
    if len(image_filenames) == 0:
        error('No input images found')

    img = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1 # img.ndim是指图像数组的维度,灰度图像是2,彩色图像是3
    # 举例,某个图像的shape是(1024, 1024, 3)
    if img.shape[1] != resolution:
        error('Input images must have the same width and height')
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        error('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        error('Input images must be stored as RGB or grayscale')

    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        # 乱序或顺序
        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            if channels == 1:
                img = img[np.newaxis, :, :] # HW => CHW,高宽=>通道数·高·宽
            else:
                img = img.transpose([2, 0, 1]) # HWC => CHW,高宽通道数=>通道数·高·宽
            # 按不同的lod,将对应的shape和图像信息序列化存入文件
            tfr.add_image(img)

create_from_images()函数的核心功能由add_image()完成,以源图像尺寸为1024x1024为例,它主要是按照lod从0--8,按不同尺寸生成大小各不相同的shape和图像数据,分别序列化存储到tfrecords-r10.tfrecords到tfrecords-r02.tfrecords等9个临时文件中,其源代码(含中文注释)如下:

    def add_image(self, img):
        if self.print_progress and self.cur_images % self.progress_interval == 0:
            print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True)
        if self.shape is None:
            self.shape = img.shape
            self.resolution_log2 = int(np.log2(self.shape[1]))
            assert self.shape[0] in [1, 3]
            assert self.shape[1] == self.shape[2]
            assert self.shape[1] == 2**self.resolution_log2
            tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
            # 举例,如果图像是1024x1024,则lod从0--8,tfr_file命名从10--2,如:tfrecords-r05.tfrecords
            for lod in range(self.resolution_log2 - 1):
                tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod) # 10 - lod
                # TFRecordWriter是将记录写入TFRecor
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值