部分源代码示例:
模型训练代码:mmedit/models/restorers/real_basicvsr.py
数据处理代码:mmedit/datasets/pipelines/random_degradations.py
配置文件里对lq数据进行随机的退化,输入到生成网络,鉴别网络再鉴别生成网络的输出和原始图像。
@MODELS.register_module()
class RealBasicVSR(RealESRGAN):
    """RealBasicVSR model for real-world video super-resolution.

    Trains a generator against randomly degraded LQ clips, with an optional
    discriminator (GAN loss), pixel loss, perceptual loss, and an image
    "cleaning" loss on the generator's intermediate cleaned LQ sequence.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict): Config for the discriminator. Default: None.
        gan_loss (dict): Config for the GAN loss. Default: None.
        pixel_loss (dict): Config for the pixel loss. Default: None.
        cleaning_loss (dict): Config for the image cleaning loss.
            Default: None.
        perceptual_loss (dict): Config for the perceptual loss. Default: None.
        is_use_sharpened_gt_in_pixel (bool): Whether to use the sharpened GT
            for the pixel loss. Default: False.
        is_use_sharpened_gt_in_percep (bool): Whether to use the sharpened GT
            for the perceptual loss. Default: False.
        is_use_sharpened_gt_in_gan (bool): Whether to use the sharpened GT
            for the GAN loss. Default: False.
        is_use_ema (bool): Whether to apply exponential moving average on the
            generator. Default: True.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_loss=None,
                 pixel_loss=None,
                 cleaning_loss=None,
                 perceptual_loss=None,
                 is_use_sharpened_gt_in_pixel=False,
                 is_use_sharpened_gt_in_percep=False,
                 is_use_sharpened_gt_in_gan=False,
                 is_use_ema=True,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__(generator, discriminator, gan_loss, pixel_loss,
                         perceptual_loss, is_use_sharpened_gt_in_pixel,
                         is_use_sharpened_gt_in_percep,
                         is_use_sharpened_gt_in_gan, is_use_ema, train_cfg,
                         test_cfg, pretrained)
        # Extra loss on the generator's intermediate cleaned LQ sequence;
        # everything else is handled by the RealESRGAN parent.
        self.cleaning_loss = build_loss(
            cleaning_loss) if cleaning_loss else None

    def train_step(self, data_batch, optimizer):
        """Run one training iteration: generator step, then discriminator.

        Args:
            data_batch (dict): Batch with keys ``'lq'`` and ``'gt'`` holding
                ``(n, t, c, h, w)`` video tensors, plus ``'gt_unsharp'`` when
                any ``is_use_sharpened_gt_in_*`` flag is set.
            optimizer (dict): Optimizers keyed by ``'generator'`` and
                ``'discriminator'``.

        Returns:
            dict: ``log_vars``, ``num_samples`` and ``results`` (lq/gt/output
            tensors moved to CPU).
        """
        import torch.nn.functional as F  # local: no visible top-level import

        # data
        lq = data_batch['lq']
        gt = data_batch['gt']

        # Per-loss targets; each may independently be the sharpened GT.
        gt_pixel, gt_percep, gt_gan = gt.clone(), gt.clone(), gt.clone()
        if self.is_use_sharpened_gt_in_pixel:
            gt_pixel = data_batch['gt_unsharp']
        if self.is_use_sharpened_gt_in_percep:
            gt_percep = data_batch['gt_unsharp']
        if self.is_use_sharpened_gt_in_gan:
            gt_gan = data_batch['gt_unsharp']

        n, t, c, h, w = gt.size()

        if self.cleaning_loss:
            # Cleaning target: GT area-downsampled by 4x, matching the
            # resolution of the generator's cleaned LQ sequence.
            gt_clean = gt_pixel.view(-1, c, h, w)
            gt_clean = F.interpolate(gt_clean, scale_factor=0.25, mode='area')
            gt_clean = gt_clean.view(n, t, c, h // 4, w // 4)

        # generator: fake_g_output is the SR result; fake_g_lq is the
        # intermediate cleaned LQ sequence (returned via return_lqs=True).
        fake_g_output, fake_g_lq = self.generator(lq, return_lqs=True)
        losses = dict()
        log_vars = dict()

        # Fold the temporal dimension into the batch so image-wise losses
        # and the discriminator see (n*t, c, h, w) frames.
        fake_g_output = fake_g_output.view(-1, c, h, w)
        gt_pixel = gt_pixel.view(-1, c, h, w)
        gt_percep = gt_percep.view(-1, c, h, w)
        gt_gan = gt_gan.view(-1, c, h, w)

        # Update the generator only every disc_steps iterations, and only
        # after the discriminator warm-up period.
        if (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps):
            if self.pixel_loss:
                losses['loss_pix'] = self.pixel_loss(fake_g_output, gt_pixel)
            if self.cleaning_loss:
                losses['loss_clean'] = self.cleaning_loss(fake_g_lq, gt_clean)
            if self.perceptual_loss:
                loss_percep, loss_style = self.perceptual_loss(
                    fake_g_output, gt_percep)
                if loss_percep is not None:
                    losses['loss_perceptual'] = loss_percep
                if loss_style is not None:
                    losses['loss_style'] = loss_style

            # GAN loss for the generator: freeze D so only G gets gradients.
            if self.gan_loss:
                set_requires_grad(self.discriminator, False)
                fake_g_pred = self.discriminator(fake_g_output)
                losses['loss_gan'] = self.gan_loss(
                    fake_g_pred, target_is_real=True, is_disc=False)

            # parse loss
            loss_g, log_vars_g = self.parse_losses(losses)
            log_vars.update(log_vars_g)

            # optimize
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

        # discriminator
        if self.gan_loss:
            set_requires_grad(self.discriminator, True)
            # real
            real_d_pred = self.discriminator(gt_gan)
            loss_d_real = self.gan_loss(
                real_d_pred, target_is_real=True, is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_real=loss_d_real))
            optimizer['discriminator'].zero_grad()
            loss_d.backward()
            log_vars.update(log_vars_d)
            # fake: detach so the discriminator loss cannot update G.
            fake_d_pred = self.discriminator(fake_g_output.detach())
            loss_d_fake = self.gan_loss(
                fake_d_pred, target_is_real=False, is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_fake=loss_d_fake))
            loss_d.backward()
            log_vars.update(log_vars_d)
            optimizer['discriminator'].step()

        self.step_counter += 1

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))

        return outputs