Introduction
ESRGAN (Enhanced Super-Resolution Generative Adversarial Networks) improves on SRGAN and produces super-resolved images that look more photo-realistic, with sharper high-frequency detail. SRGAN already raised visual quality substantially by introducing an adversarial (GAN) loss and a perceptual loss, but it still suffers from texture distortion and upscaling artifacts. ESRGAN improves both the network architecture (the RRDB block) and the adversarial training scheme (a relativistic discriminator), achieving better subjective and objective results on multiple datasets.
Core Ideas of ESRGAN
- RRDB backbone
  - Uses a "residual-in-residual" design combined with DenseNet-style RDBs (Residual Dense Blocks) and removes all BN layers, so rich high-frequency features can be extracted stably even in very deep networks.
- Relativistic discriminator
  - Instead of simply deciding "real / fake", the discriminator judges the relative realness of a real image versus a generated one;
  - A RaGAN-style (Relativistic Average GAN) formulation makes the discriminator more sensitive to subtle texture differences.
- Perceptual loss + (optional) pixel loss
  - The distance between the generated and real images is measured in a high-level feature space (e.g. VGG), pushing the network toward consistent texture structure;
  - A small $\ell_1$ or $\ell_2$ pixel loss can be kept to help stabilize training.
- Residual scaling
  - Inside each RRDB, the residual output is multiplied by a small scaling factor (e.g. 0.2) to mitigate gradient explosion and stabilize large-scale training (see the sketch right after this list).
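As a quick illustration of the residual scaling idea, here is a minimal PyTorch-style sketch (not the actual ESRGAN block, just the wrapping pattern): any sub-network used as a residual branch has its output damped by β = 0.2 before being added back to the input.

import torch
import torch.nn as nn

class ScaledResidual(nn.Module):
    # Minimal sketch: wrap a sub-network with a scaled residual connection
    def __init__(self, body, beta=0.2):
        super().__init__()
        self.body = body    # e.g. an RDB, or any stack of conv layers
        self.beta = beta    # small factor keeps the residual update gentle

    def forward(self, x):
        return x + self.beta * self.body(x)

# Usage: wrap a single conv layer just to show the mechanism
block = ScaledResidual(nn.Conv2d(64, 64, 3, padding=1))
y = block(torch.randn(1, 64, 32, 32))   # output has the same shape as the input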
Network Architecture and Mathematical Formulation
RRDB Backbone (Residual in Residual Dense Block)
The ESRGAN generator uses RRDBs instead of conventional ResBlocks as its basic building block. An RRDB consists of three RDBs in series, wrapped by an outer residual connection with a scaling factor.
- RDB (Residual Dense Block)
  In an RDB, the input of every convolution is the concatenation of the outputs of all previous layers (dense connectivity), and a local residual is added at the output (a concrete channel-count example follows this list):
  $$x_{d} = \sigma\bigl(W_d * [x_0, x_1, \dots, x_{d-1}] + b_d\bigr), \quad d = 1 \dots D,$$
  $$\text{RDB}(x_0) = x_0 + W_{\text{lff}} * [x_0, x_1, \dots, x_D],$$
  where $D$ is the number of convolution layers inside the RDB and $\sigma$ is usually LeakyReLU or ReLU.
- RRDB = RDB + RDB + RDB + residual
  $$\text{RRDB}(x) = x + \beta \cdot \bigl(\mathcal{F}_{\text{RDB}_3} \circ \mathcal{F}_{\text{RDB}_2} \circ \mathcal{F}_{\text{RDB}_1}(x)\bigr),$$
  where $\beta \approx 0.2$ controls the magnitude of the residual output and $\circ$ denotes function composition.
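As a concrete illustration of the dense connectivity (using the default channel counts from the code example later in this article: 64 input channels, growth 32, $D = 5$): the successive convolutions see 64, 96, 128, 160 and 192 input channels, the final concatenation $[x_0, x_1, \dots, x_5]$ has 64 + 5·32 = 224 channels, and the 1×1 fusion convolution $W_{\text{lff}}$ maps it back to 64 channels before the local residual is added.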
Generator
The ESRGAN generator $G_\theta$ (with parameters $\theta$) consists of the following stages (a quick shape walk-through follows the list):
- Head convolution: initial feature extraction
- RRDB trunk: a stack of N RRDBs forming the deep representation
- Global residual: the trunk output is added to the head features
- Upsampling: usually PixelShuffle layers up to the target factor (e.g. x4)
- Output convolution: maps back to 3 channels (RGB), producing $\hat{I}_{\text{SR}} = G_\theta(I_{\text{LR}})$
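As a rough sanity check (assuming the simplified ESRGANGenerator class defined in the code example later in this article), a 32×32 LR input with scale_factor=4 should come out at 128×128:

import torch
# Assumes the ESRGANGenerator class from the code example below has been defined.
g = ESRGANGenerator(scale_factor=4)   # 23 RRDBs, 64 feature channels by default
lr = torch.randn(1, 3, 32, 32)        # a dummy 32x32 LR image
with torch.no_grad():
    sr = g(lr)
print(sr.shape)                        # expected: torch.Size([1, 3, 128, 128])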
Discriminator
- The discriminator $D_\phi$ outputs a score for an input image; ESRGAN uses a relativistic GAN formulation (e.g. RaGAN), so the discriminator focuses on the relative difference between the scores of real and generated images.
- A typical architecture is convolutions + BN + LeakyReLU followed by fully connected layers, similar to SRGAN, but the loss is computed on the difference between real and fake scores.
Loss Functions
Adversarial Loss: Relativistic Discriminator
In the RaGAN setting, the adversarial loss is no longer plain binary classification; it involves a judgment of relative realness:
- Discriminator loss $\mathcal{L}_D$:
  $$\mathcal{L}_D = -\Bigl[\log\bigl(D_{\text{Ra}}(x_r, x_f)\bigr) + \log\bigl(1 - D_{\text{Ra}}(x_f, x_r)\bigr)\Bigr], \qquad D_{\text{Ra}}(x_r, x_f) = \sigma\bigl(D_\phi(x_r) - \mathbb{E}[D_\phi(x_f)]\bigr),$$
  where $x_r$ is a real image, $x_f$ a generated image, and $\sigma$ is the sigmoid function.
- Generator adversarial loss $\mathcal{L}_G^{\text{adv}}$:
  $$\mathcal{L}_G^{\text{adv}} = -\Bigl[\log\bigl(1 - D_{\text{Ra}}(x_r, x_f)\bigr) + \log\bigl(D_{\text{Ra}}(x_f, x_r)\bigr)\Bigr].$$
  The generator wants the fake image to look more realistic than the real one (and the real one to look less "absolutely real" relative to the fake); see the code sketch below.
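Both losses map directly onto BCEWithLogitsLoss. The sketch below is a minimal illustration (not taken verbatim from the paper); d_real and d_fake are assumed to be raw, pre-sigmoid discriminator scores for a batch of real and generated images, and each loss is averaged over its two terms, as in the training loop later in this article.

import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()  # applies the sigmoid internally

def relativistic_d_loss(d_real, d_fake):
    # L_D: the real image should score higher than the average fake, and vice versa
    ones, zeros = torch.ones_like(d_real), torch.zeros_like(d_fake)
    return (bce(d_real - d_fake.mean(), ones) + bce(d_fake - d_real.mean(), zeros)) / 2

def relativistic_g_loss(d_real, d_fake):
    # L_G^adv: the generator pushes the same comparison the other way
    ones, zeros = torch.ones_like(d_real), torch.zeros_like(d_fake)
    return (bce(d_real - d_fake.mean(), zeros) + bce(d_fake - d_real.mean(), ones)) / 2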
Perceptual Loss
As in SRGAN, ESRGAN defines a perceptual distance using features from one or more layers of a pre-trained network such as VGG, for example:
$$\mathcal{L}_{\text{per}} = \sum_{l} \alpha_l \cdot \frac{1}{C_l H_l W_l} \sum_{c,i,j} \Bigl(\phi_l(\hat{I}_{\text{SR}})_{c,i,j} - \phi_l(I_{\text{HR}})_{c,i,j}\Bigr)^2,$$
where $\phi_l$ is the feature map from the $l$-th VGG layer and $\alpha_l$ is its weight. Measuring the distance at several layers balances low-level texture against high-level semantic structure.
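A minimal sketch of such a weighted multi-layer version is shown below; the layer indices and weights are purely illustrative (the simplified implementation later in this article uses a single VGG layer instead).

import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class MultiLayerVGGLoss(nn.Module):
    def __init__(self, layer_ids=(8, 17, 35), weights=(0.2, 0.3, 0.5)):
        super().__init__()
        self.vgg = models.vgg19(pretrained=True).features.eval()
        self.layer_ids = set(layer_ids)
        self.weights = dict(zip(layer_ids, weights))
        for p in self.vgg.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        loss, x, y = 0.0, sr, hr
        for idx, layer in enumerate(self.vgg):
            x, y = layer(x), layer(y)
            if idx in self.layer_ids:
                # alpha_l times the mean squared feature distance (the mean supplies 1 / (C_l H_l W_l))
                loss = loss + self.weights[idx] * F.mse_loss(x, y)
            if idx >= max(self.layer_ids):
                break
        return loss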
Pixel Loss (Optional)
ESRGAN can optionally keep a small $\ell_1$ or $\ell_2$ pixel-level loss to help the network converge stably early in training. Writing the pixel loss as $\mathcal{L}_{\text{pix}} = \|\hat{I}_{\text{SR}} - I_{\text{HR}}\|_1$ (or $\|\cdot\|_2$), its weight is usually kept small.
Total Loss
The generator minimizes
$$\mathcal{L}_G = \mathcal{L}_{\text{per}} + \lambda_{\text{adv}} \cdot \mathcal{L}_G^{\text{adv}} + \mu \cdot \mathcal{L}_{\text{pix}},$$
where $\lambda_{\text{adv}}$ and $\mu$ are weighting hyper-parameters. The sketch below shows how the three terms are typically combined in code.
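A minimal sketch of this weighted combination (lambda_adv = 5e-3 and mu = 1e-2 are illustrative orders of magnitude; the training loop later in this article uses lambda_adv = 1e-3 and omits the pixel term):

import torch.nn.functional as F

def generator_loss(sr, hr, d_real, d_fake, perceptual_loss, lambda_adv=5e-3, mu=1e-2):
    l_per = perceptual_loss(sr, hr)              # VGG feature distance
    l_adv = relativistic_g_loss(d_real, d_fake)  # from the sketch in the adversarial-loss section
    l_pix = F.l1_loss(sr, hr)                    # optional pixel term
    return l_per + lambda_adv * l_adv + mu * l_pix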
Training Procedure
- Data preparation: downsample HR images to obtain LR-HR pairs;
- Discriminator update:
  - Use the relativistic discriminator loss $\mathcal{L}_D$ and back-propagate it on real images (label = 1) and generated images (label = 0);
- Generator update:
  - Compute the adversarial loss $\mathcal{L}_G^{\text{adv}}$, the perceptual loss $\mathcal{L}_{\text{per}}$ and, optionally, the pixel loss $\mathcal{L}_{\text{pix}}$;
  - Combine them with their weights into $\mathcal{L}_G$ and back-propagate;
- Alternate the discriminator and generator updates;
- Inference: run the trained generator on new LR images to perform super-resolution.
Main Differences from SRGAN
- Backbone: ESRGAN replaces SRGAN's residual blocks (ResBlocks) with RRDBs, combining dense connections with multiple levels of residuals and removing BN;
- Relativistic GAN: SRGAN uses an ordinary binary GAN, whereas ESRGAN uses the relativistic formulation in which the discriminator models the relative difference between real and fake images rather than an absolute real / fake classification;
- More stable, sharper high frequencies: in practice ESRGAN yields sharper textures and better perceptual quality, largely overcoming SRGAN's unstable textures and odd artifacts.
Detailed Mathematical Formulas
The main formulas used in ESRGAN training are listed below:
- Generator output:
  $$\hat{I}_{\text{SR}} = G_\theta(I_{\text{LR}}).$$
- Discriminator outputs:
  $$d_r = D_\phi(I_{\text{HR}}), \quad d_f = D_\phi(\hat{I}_{\text{SR}}).$$
- Relativistic discriminator (RaGAN):
  $$D_{\text{Ra}}(x_r, x_f) = \sigma\bigl(D_\phi(x_r) - \mathbb{E}[D_\phi(x_f)]\bigr), \quad D_{\text{Ra}}(x_f, x_r) = \sigma\bigl(D_\phi(x_f) - \mathbb{E}[D_\phi(x_r)]\bigr).$$
- Discriminator loss:
  $$\mathcal{L}_D = -\Bigl[\log D_{\text{Ra}}(x_r, x_f) + \log\bigl(1 - D_{\text{Ra}}(x_f, x_r)\bigr)\Bigr].$$
- Generator adversarial loss:
  $$\mathcal{L}_G^{\text{adv}} = -\Bigl[\log\bigl(1 - D_{\text{Ra}}(x_r, x_f)\bigr) + \log\bigl(D_{\text{Ra}}(x_f, x_r)\bigr)\Bigr].$$
- Perceptual loss:
  $$\mathcal{L}_{\text{per}} = \sum_{l} \alpha_l \cdot \|\phi_l(\hat{I}_{\text{SR}}) - \phi_l(I_{\text{HR}})\|_2^2.$$
- Pixel loss (optional):
  $$\mathcal{L}_{\text{pix}} = \|\hat{I}_{\text{SR}} - I_{\text{HR}}\|_1 \quad\text{or}\quad \|\cdot\|_2.$$
- Generator total loss:
  $$\mathcal{L}_G = \mathcal{L}_{\text{per}} + \lambda_{\text{adv}} \cdot \mathcal{L}_G^{\text{adv}} + \mu \cdot \mathcal{L}_{\text{pix}}.$$
Code Example
Below is a simplified PyTorch implementation of ESRGAN that demonstrates the main components: the RRDB structure and the relativistic discriminator idea. It is intended for teaching purposes only; a real project would need additional details and optimization strategies.
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn.functional as F
# =========== 1. Dataset (example) ===========
class SimpleSRDataset(Dataset):
    def __init__(self, lr_images, hr_images, transform=None):
        self.lr_images = lr_images
        self.hr_images = hr_images
        self.transform = transform

    def __len__(self):
        return len(self.lr_images)

    def __getitem__(self, idx):
        lr_img = self.lr_images[idx]
        hr_img = self.hr_images[idx]
        if self.transform:
            lr_img = self.transform(lr_img)
            hr_img = self.transform(hr_img)
        return lr_img, hr_img
# =========== 2. RDB (Residual Dense Block) ===========
class RDB(nn.Module):
    def __init__(self, in_channels, growth_channels=32, num_layers=5):
        super(RDB, self).__init__()
        self.num_layers = num_layers
        self.in_channels = in_channels
        self.growth = growth_channels
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            self.convs.append(nn.Conv2d(in_channels + i * growth_channels, growth_channels, 3, 1, 1))
        self.lff = nn.Conv2d(in_channels + num_layers * growth_channels, in_channels, 1, 1, 0)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        features = [x]
        for i in range(self.num_layers):
            out = torch.cat(features, dim=1)
            out = self.convs[i](out)
            out = self.relu(out)
            features.append(out)
        out = torch.cat(features, dim=1)
        out = self.lff(out)
        return x + out
# =========== 3. RRDB (Residual in Residual Dense Block) ===========
class RRDB(nn.Module):
    def __init__(self, in_channels, growth_channels=32, num_layers=5, scale=0.2):
        super(RRDB, self).__init__()
        self.rdb1 = RDB(in_channels, growth_channels, num_layers)
        self.rdb2 = RDB(in_channels, growth_channels, num_layers)
        self.rdb3 = RDB(in_channels, growth_channels, num_layers)
        self.scale = scale

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return x + out * self.scale
# =========== 4. Generator (ESRGAN) ===========
class ESRGANGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_feats=64, n_rrdb=23, scale_factor=4):
        super(ESRGANGenerator, self).__init__()
        # 1) Head feature extraction
        self.conv_head = nn.Conv2d(in_channels, num_feats, kernel_size=3, padding=1)
        # 2) RRDB trunk
        blocks = []
        for _ in range(n_rrdb):
            blocks.append(RRDB(num_feats, 32, 5, scale=0.2))
        self.trunk = nn.Sequential(*blocks)
        self.conv_trunk = nn.Conv2d(num_feats, num_feats, 3, 1, 1)
        # 3) Upsampling: one x2 (conv + PixelShuffle) stage per power of two, e.g. scale_factor=4 => two x2 stages
        upsample = []
        for _ in range(int(math.log2(scale_factor))):
            upsample.append(nn.Conv2d(num_feats, num_feats * 4, 3, 1, 1))
            upsample.append(nn.PixelShuffle(2))
            upsample.append(nn.LeakyReLU(0.2, inplace=True))
        self.upsample = nn.Sequential(*upsample)
        # 4) Output layer
        self.conv_last = nn.Conv2d(num_feats, out_channels, 3, 1, 1)

    def forward(self, x):
        feat = self.conv_head(x)
        trunk = self.trunk(feat)
        trunk = self.conv_trunk(trunk)
        trunk = trunk + feat  # global residual
        out = self.upsample(trunk)
        out = self.conv_last(out)
        return out
# =========== 5. Discriminator (relativistic GAN example, simplified) ===========
class ESRGANDiscriminator(nn.Module):
    def __init__(self):
        super(ESRGANDiscriminator, self).__init__()
        # Similar structure to SRGAN; only the loss computation becomes relativistic
        layers = []
        in_c = 3
        layers += [
            nn.Conv2d(in_c, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True),
        ]
        self.main = nn.Sequential(*layers)
        self.fc = nn.Linear(512 * 8 * 8, 1)  # assumes 128x128 inputs (for demonstration only)

    def forward(self, x):
        out = self.main(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
# =========== 6. VGG perceptual loss ===========
class VGGPerceptualLoss(nn.Module):
    def __init__(self, layer_idx=35):
        super(VGGPerceptualLoss, self).__init__()
        vgg = models.vgg19(pretrained=True).features
        self.slice = nn.Sequential(*list(vgg.children())[:layer_idx])
        for p in self.slice.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        sr_feat = self.slice(sr)
        hr_feat = self.slice(hr)
        return F.mse_loss(sr_feat, hr_feat)
# =========== 7. ESRGAN main training loop (example) ===========
def train_esrgan(lr_images, hr_images, epochs=10, batch_size=4, learning_rate=1e-4):
    dataset = SimpleSRDataset(lr_images, hr_images, transform=transforms.ToTensor())
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    G = ESRGANGenerator().to(device)
    D = ESRGANDiscriminator().to(device)
    # Perceptual loss
    vgg_loss = VGGPerceptualLoss().to(device)
    bce_loss = nn.BCEWithLogitsLoss()  # adversarial loss
    optimizer_g = optim.Adam(G.parameters(), lr=learning_rate, betas=(0.9, 0.999))
    optimizer_d = optim.Adam(D.parameters(), lr=learning_rate, betas=(0.9, 0.999))

    for epoch in range(epochs):
        G.train()
        D.train()
        total_g_loss = 0.0
        total_d_loss = 0.0
        for lr_batch, hr_batch in dataloader:
            lr_batch = lr_batch.to(device)
            hr_batch = hr_batch.to(device)

            # ==== 1) Discriminator update ====
            # Generate SR images
            sr_out = G(lr_batch)
            d_real = D(hr_batch)
            d_fake = D(sr_out.detach())
            # Relativistic discriminator (illustrative formulation)
            real_label = torch.ones_like(d_real)
            fake_label = torch.zeros_like(d_fake)
            d_loss_real = bce_loss(d_real - torch.mean(d_fake), real_label)
            d_loss_fake = bce_loss(d_fake - torch.mean(d_real), fake_label)
            d_loss = (d_loss_real + d_loss_fake) / 2
            optimizer_d.zero_grad()
            d_loss.backward()
            optimizer_d.step()

            # ==== 2) Generator update ====
            sr_out = G(lr_batch)
            d_real2 = D(hr_batch)
            d_fake2 = D(sr_out)
            # Adversarial loss
            g_adv_loss = bce_loss(d_real2 - torch.mean(d_fake2), fake_label) \
                         + bce_loss(d_fake2 - torch.mean(d_real2), real_label)
            g_adv_loss /= 2
            # Perceptual loss
            g_per_loss = vgg_loss(sr_out, hr_batch)
            # (Optional) pixel loss
            # g_pix_loss = F.l1_loss(sr_out, hr_batch)
            # Total loss
            lambda_adv = 1e-3
            g_loss = g_per_loss + lambda_adv * g_adv_loss
            optimizer_g.zero_grad()
            g_loss.backward()
            optimizer_g.step()

            total_d_loss += d_loss.item()
            total_g_loss += g_loss.item()

        avg_d = total_d_loss / len(dataloader)
        avg_g = total_g_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs} | D_loss: {avg_d:.4f} | G_loss: {avg_g:.4f}")
    return G, D
# =========== 8. Inference ===========
def inference_esrgan(generator, lr_image):
    generator.eval()
    transform = transforms.ToTensor()
    lr_tensor = transform(lr_image).unsqueeze(0)
    device = next(generator.parameters()).device
    lr_tensor = lr_tensor.to(device)
    with torch.no_grad():
        sr_tensor = generator(lr_tensor)
    sr_tensor = sr_tensor.clamp(0, 1)
    return sr_tensor.squeeze(0).cpu()
Code Walkthrough
- RDB (Residual Dense Block)
  - Within a block, each convolution takes as input the concatenation of all previous layer outputs (dense connectivity); a 1×1 convolution fuses the features at the output, which is then added to the block input as a local residual;
  - This combines the DenseNet idea with a local residual design.
- RRDB (Residual in Residual Dense Block)
  - Three RDBs in series, wrapped in an outer residual path whose output is multiplied by a scaling factor (0.2 by default), forming "residuals within residuals";
  - This design keeps gradients under control in very deep networks, avoids training instability, and uses the stacked residuals and dense connections to recover fine detail.
- ESRGANGenerator
  - conv_head: maps the LR input into a higher-dimensional feature space;
  - RRDB trunk: a stack of RRDBs for deep feature extraction;
  - conv_trunk output is added to the head features, forming a global residual;
  - Upsampling: pixel-level enlargement via PixelShuffle;
  - conv_last: produces the final super-resolved image.
- ESRGANDiscriminator
  - Several stride-2 convolutions + BatchNorm + LeakyReLU followed by a fully connected layer that outputs a single score;
  - During training, the relativistic GAN formulas turn this score into a relative real-vs-fake comparison.
- VGGPerceptualLoss
  - Uses high-level features of a pre-trained VGG19 as the perceptual loss, measuring differences in texture and structure between SR and HR to preserve visual quality.
- Training loop
  - Discriminator update: the relativistic scores (real HR should score high, fake SR low) are fed into a BCE-with-logits loss;
  - Generator update: the adversarial (relativistic) loss and the perceptual (VGG) loss, optionally plus a pixel loss, are combined and back-propagated to update the generator;
  - The two updates alternate until convergence, after which the network produces sharper, more realistic SR images.
- Inference
  - Simply load the trained generator and run a forward pass on an LR image to obtain the high-resolution result (a small usage sketch follows).
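A minimal end-to-end usage sketch, assuming the code above has been run (the file names and the 128→32 downscaling are purely illustrative; any aligned LR/HR pairs will do):

from PIL import Image

# Hypothetical data: a few HR crops and their downsampled LR counterparts
hr_images = [Image.open(f"hr_{i}.png").convert("RGB").resize((128, 128)) for i in range(8)]
lr_images = [img.resize((32, 32)) for img in hr_images]   # default (bicubic) filter

G, D = train_esrgan(lr_images, hr_images, epochs=2, batch_size=2)

sr = inference_esrgan(G, lr_images[0])           # tensor of shape [3, 128, 128]
transforms.ToPILImage()(sr).save("sr_0.png")     # save the super-resolved result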