RealBasicVSR源码解析

部分源代码示例:
模型训练代码:mmedit/models/restorers/real_basicvsr.py
数据处理代码:mmedit/datasets/pipelines/random_degradations.py

配置文件里对lq数据进行随机的退化,输入到生成网络,鉴别网络再鉴别生成网络的输出和原始图像。

@MODELS.register_module()
class RealBasicVSR(RealESRGAN):
    """RealBasicVSR model for real-world video super-resolution.

    Ref:
            pretrained (str): Path for pretrained model. Default: None.
    """

    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_loss=None,
                 pixel_loss=None,
                 cleaning_loss=None,
                 perceptual_loss=None,
                 is_use_sharpened_gt_in_pixel=False,
                 is_use_sharpened_gt_in_percep=False,
                 is_use_sharpened_gt_in_gan=False,
                 is_use_ema=True,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):

        super().__init__(generator, discriminator, gan_loss, pixel_loss,
                         perceptual_loss, is_use_sharpened_gt_in_pixel,
                         is_use_sharpened_gt_in_percep,
                         is_use_sharpened_gt_in_gan, is_use_ema, train_cfg,
                         test_cfg, pretrained)

        self.cleaning_loss = build_loss(
            cleaning_loss) if cleaning_loss else None

    def train_step(self, data_batch, optimizer):
        """Train step.
		...
		        # data
        lq = data_batch['lq']
        gt = data_batch['gt']
        # generator
        fake_g_output, fake_g_lq = self.generator(lq, return_lqs=True)       ##低质量的图片,输入到生成网络,得到fake图片。fake_g_output应该是模型的输出,fake_g_lq不太确定是什么。
        losses = dict()
        log_vars = dict()

	  fake_g_output = fake_g_output.view(-1, c, h, w)     ##输出转化一下维度
	  if (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps):

###选择其中一种loss计算
            if self.pixel_loss:
                losses['loss_pix'] = self.pixel_loss(fake_g_output, gt_pixel)
            if self.cleaning_loss:
                losses['loss_clean'] = self.cleaning_loss(fake_g_lq, gt_clean)
            if self.perceptual_loss:
                loss_percep, loss_style = self.perceptual_loss(
                    fake_g_output, gt_percep)
                if loss_percep is not None:
                    losses['loss_perceptual'] = loss_percep
                if loss_style is not None:
                    losses['loss_style'] = loss_style

            # gan loss for generator,让鉴别器去鉴别假的数据
            if self.gan_loss:
                fake_g_pred = self.discriminator(fake_g_output)
                losses['loss_gan'] = self.gan_loss(
                    fake_g_pred, target_is_real=True, is_disc=False)

            # parse loss
            loss_g, log_vars_g = self.parse_losses(losses)
            log_vars.update(log_vars_g)

            # optimize
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

        # discriminator
        if self.gan_loss:
            set_requires_grad(self.discriminator, True)
            # real
            real_d_pred = self.discriminator(gt_gan)
            loss_d_real = self.gan_loss(
                real_d_pred, target_is_real=True, is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_real=loss_d_real))
            optimizer['discriminator'].zero_grad()
            loss_d.backward()
            log_vars.update(log_vars_d)

            # fake
            fake_d_pred = self.discriminator(fake_g_output.detach())
            loss_d_fake = self.gan_loss(
                fake_d_pred, target_is_real=False, is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_fake=loss_d_fake))
            loss_d.backward()
            log_vars.update(log_vars_d)

            optimizer['discriminator'].step()

        self.step_counter += 1

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))

        return outputs
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值