【Python/Pytorch - 网络模型】-- 手把手搭建GVloss

在这里插入图片描述

文章目录

00 写在前面

这段代码定义了一个名为 gvloss 的损失函数,用于计算输出图像和目标图像在梯度方向上的差异,特别适用于图像生成或图像恢复任务。

01 程序

        # Sobel kernel for the gradient map calculation
        self.kernel_x = torch.FloatTensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).unsqueeze(0).unsqueeze(0).cuda()
        self.kernel_y = torch.FloatTensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).unsqueeze(0).unsqueeze(0).cuda()
        # operation for unfolding image into non overlapping patches
        self.unfold = torch.nn.Unfold(kernel_size=(self.patch_size, self.patch_size), stride=self.patch_size)

def gvloss(self, output, target):

        # calculation of the gradient maps of x and y directions
        gx_target = F.conv2d(target, self.kernel_x, stride=1, padding=1)
        gy_target = F.conv2d(target, self.kernel_y, stride=1, padding=1)
        gx_output = F.conv2d(output, self.kernel_x, stride=1, padding=1)
        gy_output = F.conv2d(output, self.kernel_y, stride=1, padding=1)

        # unfolding image to patches
        gx_target_patches = self.unfold(gx_target)
        gy_target_patches = self.unfold(gy_target)
        gx_output_patches = self.unfold(gx_output)
        gy_output_patches = self.unfold(gy_output)

        # calculation of variance of each patch
        var_target_x = torch.var(gx_target_patches, dim=1)
        var_output_x = torch.var(gx_output_patches, dim=1)
        var_target_y = torch.var(gy_target_patches, dim=1)
        var_output_y = torch.var(gy_output_patches, dim=1)

        # loss function as a MSE between variances of patches extracted from gradient maps
        gradvar_loss = F.mse_loss(var_target_x, var_output_x) + F.mse_loss(var_target_y, var_output_y)

        return gradvar_loss

02 代码功能分解

1. 梯度计算

gx_target = F.conv2d(target, self.kernel_x, stride=1, padding=1)
gy_target = F.conv2d(target, self.kernel_y, stride=1, padding=1)
gx_output = F.conv2d(output, self.kernel_x, stride=1, padding=1)
gy_output = F.conv2d(output, self.kernel_y, stride=1, padding=1)
  • F.conv2d():使用卷积操作计算图像梯度。
    • self.kernel_xself.kernel_y:分别为检测水平方向和垂直方向梯度的 Sobel 滤波核。
    • targetoutput:分别为真实目标图像和模型生成的图像。
    • 结果 gx_targetgy_target 是目标图像的水平和垂直梯度图,gx_outputgy_output 是输出图像的水平和垂直梯度图。

2. 图像展开为补丁

gx_target_patches = self.unfold(gx_target)
gy_target_patches = self.unfold(gy_target)
gx_output_patches = self.unfold(gx_output)
gy_output_patches = self.unfold(gy_output)
  • self.unfold:将图像展开为 (B, C * kernel_size^2, L) 的补丁表示,其中:
    • B 是 batch size。
    • C 是通道数。
    • kernel_size^2 是每个补丁的大小。
    • L 是补丁的数量。
  • 作用:将梯度图分割成多个局部补丁,以局部区域为单位进行后续的方差计算。

3. 计算每个补丁的方差

var_target_x = torch.var(gx_target_patches, dim=1)
var_output_x = torch.var(gx_output_patches, dim=1)
var_target_y = torch.var(gy_target_patches, dim=1)
var_output_y = torch.var(gy_output_patches, dim=1)
  • torch.var:计算每个补丁在通道维度上的方差。
  • 结果 var_target_xvar_output_x 分别表示目标图像和输出图像在水平梯度方向上的每个补丁的方差,var_target_yvar_output_y 类似但针对垂直梯度方向。

4. 计算损失函数

gradvar_loss = F.mse_loss(var_target_x, var_output_x) + F.mse_loss(var_target_y, var_output_y)
  • F.mse_loss:计算均方误差(MSE)损失。
  • 作用:比较目标图像和输出图像在水平和垂直梯度方向上的方差差异,损失越小表示梯度分布越接近。

5. 返回损失值

return gradvar_loss
  • 返回计算得到的梯度方差损失值。

整体功能

  • 梯度敏感性:通过计算图像的梯度图,捕捉图像的边缘和纹理信息。
  • 局部特征损失:将图像分成补丁,计算每个补丁的梯度方差差异,关注图像的局部结构。
  • 提升生成质量:用于图像生成或超分辨率任务,帮助生成更接近真实图像的边缘和纹理细节。

应用场景

  • 图像生成:如 GAN、VAE 等生成模型的训练,作为对抗损失或重建损失的补充。
  • 图像恢复:如超分辨率、去噪、去模糊等任务,帮助恢复图像的高频细节。
  • 风格迁移:促使生成图像保留目标风格的纹理和边缘特征。

关键参数

  • self.kernel_xself.kernel_y:Sobel 滤波核,用于检测水平和垂直梯度。
  • self.unfold:用于将图像展开为补丁,补丁大小会影响局部特征的粒度。
  • self.criterion:损失函数,这里是均方误差(MSE)。

注意事项

  • 计算成本:展开为补丁和计算方差会增加计算量,可能降低训练速度。
  • 适用场景:更适用于需要保留图像结构和纹理的任务,对于全局语义信息的依赖较小。

02 论文下载

GRADIENTVARIANCELOSSFORSTRUCTURE-ENHANCEDIMAGE
SUPER-RES
OLUTION

https://github.com/lusinlu/gradient-variance-loss

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

电科_银尘

你的鼓励将是我最大的创作动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值