【代码阅读】WarpGAN: Automatic Caricature Generation

本文详细介绍了WarpGAN的代码结构,包括train.py中的主函数、warpgan.py中的网络结构和前向传播,以及default.py中的网络详细定义。通过分析,揭示了特征点如何在训练过程中生成,以及网络如何结合风格转换和几何变形来缩小照片与卡通之间的差距。此外,还讨论了WarpGAN与CariGANs的相似之处。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

代码链接

参考书籍:《Tensorflow 实战Google深度学习框架》

我觉得看一下第三章可以更清晰的了解tensorflow是怎么建立,训练一个神经网络的。

1. train.py

这份文件定义了主函数

def main(args):

初始化:

# Initalization for running
    if config.save_model:
        log_dir = utils.create_log_dir(config, config_file)
        summary_writer = tf.summary.FileWriter(log_dir, network.graph)
    if config.restore_model:
        network.restore_model(config.restore_model, config.restore_scopes)

    proc_func = lambda images: preprocess(images, config, True)
    trainset.start_batch_queue(config.batch_size, proc_func=proc_func)

这里的config参数设置都来自于文件 WarpGAN\config\default.py

数据集读取初始化等操作来自于文件 WarpGAN\utils\dataset.py

主循环:

# Main Loop
    print('\nStart Training\nname: {}\n# epochs: {}\nepoch_size: {}\nbatch_size: {}\n'.format(
            config.name, config.num_epochs, config.epoch_size, config.batch_size))
    global_step = 0
    start_time = time.time()
    for epoch in range(config.num_epochs):

        if epoch == 0: test(network, config, log_dir, global_step)

        # Training
        for step in range(config.epoch_size):
            # Prepare input
            learning_rate = utils.get_updated_learning_rate(global_step, config)
            batch = trainset.pop_batch_queue()

            wl, sm, global_step = network.train(batch['images'], batch['labels'], batch['is_photo'], learning_rate, config.keep_prob)

            wl['lr'] = learning_rate

            # Display
            if step % config.summary_interval == 0:
                duration = time.time() - start_time
                start_time = time.time()
                utils.display_info(epoch, step, duration, wl)
                if config.save_model:
                    summary_writer.add_summary(sm, global_step=global_step)

wl, sm, global_step = network.train(batch['images'], batch['labels'], batch['is_photo'], learning_rate, config.keep_prob)

这句话是重点,调用了网络的训练

2. warpgan.py

这个文件中定义了warpgan这个网络的计算图,前向传播以及损失函数。

训练神经网络的过程可以概括为下面这三个步骤:

1)定义神经网络的结构和前向传播的输出结果

2)定义损失函数(根据前向传播的输出结果计算出来的)以及反向传播优化的算法

3)生成会话(tf.Session()),并在训练数据上反复运行反向传播优化算法

    def train(self, images_batch, labels_batch, switch_batch, learning_rate, keep_prob):
        images_A = images_batch[~switch_batch]
        images_B = images_batch[switch_batch]
        labels_A = labels_batch[~switch_batch]
        labels_B = labels_batch[switch_batch]
        scales_A = np.ones((images_A.shape[0]))
        scales_B = np.ones((images_B.shape[0]))
        feed_dict = {   self.images_A: images_A,
                        self.images_B: images_B,
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值