Open-FWI代码解析(3)

本文档详细介绍了如何在PyTorch中实现结构相似性(SSIM)评价指标,包括gaussian函数、create_window函数以及SSIM计算方法。重点讲解了高斯滤波器的应用和在图像处理中的作用,以及如何在PyTorch模块中创建和使用SSIM类进行图片比较。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

1.pytorch_ssim文件

        1.1 gaussian函数

        1.2 create_window函数

        1.3 SSIM评价指标函数

        1.4 SSIM评价指标类

2. 总结


1.pytorch_ssim文件

        1.1 gaussian函数

def gaussian(window_size, sigma):
    '''
    形成大小为window_size,标准差为sigma的一串高斯序列
    :param window_size:
    :param sigma:
    :return:
    '''
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()

        1.2 create_window函数

def create_window(window_size, channel):
    '''

    :param window_size: 矩阵大小
    :param channel: 通道大小
    :return: (channel,1,window_size,window_size)大小的高斯滤波器
    '''
    # .unsqueeze(1)是插入第二个维度
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    # _1D_window.mm矩阵乘法
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # .contiguous()的作用是返回一个具有相同数据但存储方式不同的新张量。
    # 这通常用于确保张量在内存中是连续存储的,以便后续的计算操作。
    # 将处理过的二维高斯滤波器包装在Variable对象中,并将其作为函数的返回值。
    # 在较新版本的PyTorch中,Variable对象已经被整合到了torch.Tensor中,所以在实际应用中可能不再需要显式地使用Variable。
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

        这两部分可以看做生成一个高斯滤波器的代码, 至于什么是高斯滤波器, 它有什么作用, 参考这篇博客. 简而言之就是利用卷积运算, 对图像进行平滑处理, 减少高频部分, 保留低频部分. 由于高频部分(细节)部分消失, 图像就会变得模糊. 高斯滤波器对于抑制 高斯噪声 (服从正态分布的噪声) 非常有效. 这里应该是考虑到图像生成过程中, 或多或少会有噪声, 所以加入高斯滤波器减少噪声. 也可能是减少细节部分, 对两个图片进行综合的评价(因为细节会拉大评价差距?)

        1.3 SSIM评价指标函数

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    '''
    这段代码实现了SSIM(结构相似度)指标的计算,用于度量两幅图像之间的结构相似性。
    :param img1:图像1,大小(batch_size, channel, height, width)
    :param img2:图像2,大小(batch_size, channel, height, width)
    :param window::高斯滤波器,大小(channel, 1, window_size, window_size)
    :param window_size:高斯滤波器大小 window_size
    :param channel:通道数
    :param size_average:一个布尔值,表示是否对SSIM值进行平均。
    如果为True,则返回所有像素的均值作为SSIM值;如果为False,则返回每张图像的SSIM值的均值
    :return:
    '''
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

       什么是SSIM指标参考这篇---博客

        这里再补充一点, 对于 size_average 参数, 举个例子, 若比较的图像为img1 = (10,3,5,5), img2 =(10,3,5,5)的大小,选定size_average = true, 则返回这10个图片的SSIM值的平均值; 若选定size_average = false, 则返回这10个图片对应比较的SSIM值, 即为一个大小为10的序列. 

        除此之外, mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)

        这里group = channel是将img1分为group个通道, 再与window做卷积(因为window = (channel, 1, window_size, window_size)). 这部分卷积的维度大小一定要弄明白, 可以多尝试几次.

        1.4 SSIM评价指标类

class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        '''
        :param window_size: 高斯滤波器大小
        :param size_average: 是否返回所有图片SSIM指标
        '''
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        # 得到img1的通道
        (_, channel, _, _) = img1.size()
        # 若图片的通道和滤波器的通道一致,并且图片和滤波器的数据类型一致,则保留window
        # 否则重新创建一个和图片通道一致的window
        # 修改window的设备,数据类型
        # 最后返回img1,img2的SSIM指标
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size=11, size_average=True):
    # 和class SSIM相似,不过多解释
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

2. 总结

        这部分文件主要是对img1, img2计算SSIM指标, 感觉比较简单, 主要是把SSIM的计算公式用python代码实现即可.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值