文章目录
文章目录
00 写在前面
这段代码定义了一个名为 gvloss
的损失函数,用于计算输出图像和目标图像在梯度方向上的差异,特别适用于图像生成或图像恢复任务。
01 程序
# Sobel kernel for the gradient map calculation
self.kernel_x = torch.FloatTensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).unsqueeze(0).unsqueeze(0).cuda()
self.kernel_y = torch.FloatTensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).unsqueeze(0).unsqueeze(0).cuda()
# operation for unfolding image into non overlapping patches
self.unfold = torch.nn.Unfold(kernel_size=(self.patch_size, self.patch_size), stride=self.patch_size)
def gvloss(self, output, target):
# calculation of the gradient maps of x and y directions
gx_target = F.conv2d(target, self.kernel_x, stride=1, padding=1)
gy_target = F.conv2d(target, self.kernel_y, stride=1, padding=1)
gx_output = F.conv2d(output, self.kernel_x, stride=1, padding=1)
gy_output = F.conv2d(output, self.kernel_y, stride=1, padding=1)
# unfolding image to patches
gx_target_patches = self.unfold(gx_target)
gy_target_patches = self.unfold(gy_target)
gx_output_patches = self.unfold(gx_output)
gy_output_patches = self.unfold(gy_output)
# calculation of variance of each patch
var_target_x = torch.var(gx_target_patches, dim=1)
var_output_x = torch.var(gx_output_patches, dim=1)
var_target_y = torch.var(gy_target_patches, dim=1)
var_output_y = torch.var(gy_output_patches, dim=1)
# loss function as a MSE between variances of patches extracted from gradient maps
gradvar_loss = F.mse_loss(var_target_x, var_output_x) + F.mse_loss(var_target_y, var_output_y)
return gradvar_loss
02 代码功能分解
1. 梯度计算
gx_target = F.conv2d(target, self.kernel_x, stride=1, padding=1)
gy_target = F.conv2d(target, self.kernel_y, stride=1, padding=1)
gx_output = F.conv2d(output, self.kernel_x, stride=1, padding=1)
gy_output = F.conv2d(output, self.kernel_y, stride=1, padding=1)
F.conv2d()
:使用卷积操作计算图像梯度。self.kernel_x
和self.kernel_y
:分别为检测水平方向和垂直方向梯度的 Sobel 滤波核。target
和output
:分别为真实目标图像和模型生成的图像。- 结果
gx_target
和gy_target
是目标图像的水平和垂直梯度图,gx_output
和gy_output
是输出图像的水平和垂直梯度图。
2. 图像展开为补丁
gx_target_patches = self.unfold(gx_target)
gy_target_patches = self.unfold(gy_target)
gx_output_patches = self.unfold(gx_output)
gy_output_patches = self.unfold(gy_output)
self.unfold
:将图像展开为(B, C * kernel_size^2, L)
的补丁表示,其中:B
是 batch size。C
是通道数。kernel_size^2
是每个补丁的大小。L
是补丁的数量。
- 作用:将梯度图分割成多个局部补丁,以局部区域为单位进行后续的方差计算。
3. 计算每个补丁的方差
var_target_x = torch.var(gx_target_patches, dim=1)
var_output_x = torch.var(gx_output_patches, dim=1)
var_target_y = torch.var(gy_target_patches, dim=1)
var_output_y = torch.var(gy_output_patches, dim=1)
torch.var
:计算每个补丁在通道维度上的方差。- 结果
var_target_x
和var_output_x
分别表示目标图像和输出图像在水平梯度方向上的每个补丁的方差,var_target_y
和var_output_y
类似但针对垂直梯度方向。
4. 计算损失函数
gradvar_loss = F.mse_loss(var_target_x, var_output_x) + F.mse_loss(var_target_y, var_output_y)
F.mse_loss
:计算均方误差(MSE)损失。- 作用:比较目标图像和输出图像在水平和垂直梯度方向上的方差差异,损失越小表示梯度分布越接近。
5. 返回损失值
return gradvar_loss
- 返回计算得到的梯度方差损失值。
整体功能
- 梯度敏感性:通过计算图像的梯度图,捕捉图像的边缘和纹理信息。
- 局部特征损失:将图像分成补丁,计算每个补丁的梯度方差差异,关注图像的局部结构。
- 提升生成质量:用于图像生成或超分辨率任务,帮助生成更接近真实图像的边缘和纹理细节。
应用场景
- 图像生成:如 GAN、VAE 等生成模型的训练,作为对抗损失或重建损失的补充。
- 图像恢复:如超分辨率、去噪、去模糊等任务,帮助恢复图像的高频细节。
- 风格迁移:促使生成图像保留目标风格的纹理和边缘特征。
关键参数
self.kernel_x
和self.kernel_y
:Sobel 滤波核,用于检测水平和垂直梯度。self.unfold
:用于将图像展开为补丁,补丁大小会影响局部特征的粒度。self.criterion
:损失函数,这里是均方误差(MSE)。
注意事项
- 计算成本:展开为补丁和计算方差会增加计算量,可能降低训练速度。
- 适用场景:更适用于需要保留图像结构和纹理的任务,对于全局语义信息的依赖较小。
02 论文下载
GRADIENTVARIANCELOSSFORSTRUCTURE-ENHANCEDIMAGE
SUPER-RES
OLUTION
https://github.com/lusinlu/gradient-variance-loss