### 关于 Pix2Pix 的代码实现
#### 定义生成器架构
为了构建一个基本的 Pix2Pix 模型,首先定义生成器。该生成器采用 U-Net 结构来处理输入图像并生成相应的输出图像[^2]。
```python
import torch.nn as nn
class UnetGenerator(nn.Module):
    """U-Net generator built recursively: starts from the innermost
    bottleneck block and wraps it outwards with skip-connection blocks.

    Args:
        input_nc: number of channels in the input image.
        output_nc: number of channels in the generated image.
        num_downs: number of downsampling steps (e.g. 8 for 256x256 input).
        ngf: base number of generator filters.
        norm_layer: normalization layer class to use.
        use_dropout: whether the intermediate ngf*8 blocks use dropout.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        # Innermost block (the bottleneck of the U-Net).
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        # Intermediate blocks at the ngf*8 channel width; dropout is optional here.
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block, norm_layer=norm_layer, use_dropout=use_dropout)
        # Step the channel count back down: ngf*8 -> ngf*4 -> ngf*2 -> ngf.
        for outer_nc in (ngf * 4, ngf * 2, ngf):
            block = UnetSkipConnectionBlock(outer_nc, outer_nc * 2, input_nc=None, submodule=block, norm_layer=norm_layer)
        # Outermost block maps the raw input to output_nc channels.
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer)

    def forward(self, x):
        """Run the full U-Net on a batch of images."""
        return self.model(x)
```
#### 实现鉴别器架构
接着是 PatchGAN 鉴别器的设计。这种类型的鉴别器专注于局部图像块(patch)而非整张图片,从而提高了效率并且减少了所需参数的数量[^3]。
```python
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: outputs a real/fake prediction per
    overlapping image patch rather than one scalar per image.

    Args:
        input_nc: number of channels in the input image.
        ndf: base number of discriminator filters.
        n_layers: number of strided conv stages after the first one.
        norm_layer: normalization layer class applied after each conv.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(NLayerDiscriminator, self).__init__()
        kernel, pad = 4, 1
        # First stage: strided conv + LeakyReLU, no normalization.
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]
        # Each following stage doubles the channel multiplier, capped at 8.
        prev_mult = 1
        for stage in range(1, n_layers):
            mult = min(2 ** stage, 8)
            layers += [
                nn.Conv2d(ndf * prev_mult, ndf * mult,
                          kernel_size=kernel, stride=2, padding=pad, bias=False),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ]
            prev_mult = mult
        # Penultimate stage keeps the spatial resolution (stride 1).
        mult = min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * prev_mult, ndf * mult,
                      kernel_size=kernel, stride=1, padding=pad, bias=False),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
        ]
        # Final conv collapses to a 1-channel prediction map.
        layers.append(nn.Conv2d(ndf * mult, 1, kernel_size=kernel, stride=1, padding=pad))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return the patch-wise real/fake prediction map for *input*."""
        return self.model(input)
```
#### 训练过程概览
在训练期间,对于每一对输入图及其对应的目标图,计算生成损失以及对抗性损失。通过这种方式可以优化生成器以更好地模仿目标分布,并调整鉴别器区分真假样本的能力[^5]。
```python
# Adversarial loss plus one Adam optimizer per network.
criterionGAN = GANLoss()
optimizer_G = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizer_D = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    for batch in dataset:
        real_A = batch['A'].to(device)
        real_B = batch['B'].to(device)
        fake_B = netG(real_A)

        # --- Discriminator step: detach fake_B so no gradient reaches G ---
        d_loss_fake = criterionGAN(netD(fake_B.detach()), False)
        d_loss_real = criterionGAN(netD(real_B), True)
        optimizer_D.zero_grad()
        (d_loss_fake + d_loss_real).backward()
        optimizer_D.step()

        # --- Generator step: fool D, and stay close to the target via L1 ---
        g_loss_adv = criterionGAN(netD(fake_B), True)
        g_loss_l1 = L1_loss(fake_B, real_B) * lambda_L1
        g_loss = g_loss_adv + g_loss_l1
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()
```