【IQA技术专题】NIQE代码讲解

部署运行你感兴趣的模型镜像

本文是对NIQE图像质量评价指标的代码解读,原文解读请看NIQE文章讲解
本文的代码来源于IQA-Pytorch工程。

1、原文概要

NIQE实现了无参考的图像质量评价指标,可以有效地对图像的感知(Fidelity)质量进行评估。本文提出了一种完全盲目的图像质量评估(IQA)模型 —— 自然图像质量评估器(NIQE),它基于自然场景统计(NSS)模型,通过提取自然图像的统计特征构建多元高斯(MVG)模型,再计算测试图像特征与该模型的距离来评估质量。与需训练的现有模型不同,NIQE 无需依赖失真图像或人类主观评分,在 LIVE IQA 数据库上的性能优于 PSNR、SSIM 等全参考模型,且与当时顶级的无参考模型BRISQUE相当。
评估方法大致可以分为以下4个步骤,用下图表示。
在这里插入图片描述

  1. 提取quality aware的特征,这个特征来源于空间域自然场景统计模型(NSS模型)。
  2. 在图中选取合适的块来进行评估,作者设计了一种方法来选取图像中更能代表图像质量的块,而非全部使用。
  3. 从补丁中提取到的NSS特征使用零均值广义高斯分布(GGD)来进行拟合,拟合的分布参数作者会进一步送给多元高斯模型拟合使用。
  4. 使用步骤3的分布参数特征fit in一个multivariate Gaussian (MVG) model 多元高斯分布。

这样是完成了一次MVG模型的拟合。

为进一步得到NIQE指标,作者会首先选取来自Flickr无版权数据与Berkeley图像分割数据库的125张自然图像,分辨率范围480×320至1280×720,进行一次上述的过程得到一个标准的多元高斯分布。在测试时,只需要对待评估图像重复一次上述的过程并计算两个分布之间的马氏距离即可。马氏距离计算公式如下:
D ( ν 1 , ν 2 , ∑ 1 , ∑ 2 ) = ( ν 1 − ν 2 ) T ( ∑ 1 + ∑ 2 2 ) − 1 ( ν 1 − ν 2 ) \begin{aligned} D\left(\nu_{1}, \nu_{2}, \sum_{1}, \sum_{2}\right) &= \sqrt{ \left( \nu_{1} - \nu_{2} \right)^{T} \left( \frac{\sum_{1} + \sum_{2}}{2} \right)^{-1} \left( \nu_{1} - \nu_{2} \right) } \end{aligned} D(ν1,ν2,1,2)=(ν1ν2)T(21+2)1(ν1ν2) 其中 v 1 v_1 v1 v 2 v_2 v2以及协方差矩阵是前面的多元高斯模型的参数。

2、代码结构

代码实现位于pyiqa/archs/niqe_arch.py中
在这里插入图片描述

3 、核心代码模块

NIQE

这个类实现了整体的参数传入与函数调用。

@ARCH_REGISTRY.register()
class NIQE(torch.nn.Module):
    r"""Args:
        - channels (int): Number of processed channel.
        - test_y_channel (bool): whether to use y channel on ycbcr.
        - crop_border (int): Cropped pixels in each edge of an image. These
        pixels are not involved in the metric calculation.
        - pretrained_model_path (str): The pretrained model path.
    References:
        Mittal, Anish, Rajiv Soundararajan, and Alan C. Bovik.
        "Making a “completely blind” image quality analyzer."
        IEEE Signal Processing Letters (SPL) 20.3 (2012): 209-212.
    """

    def __init__(
        self,
        channels: int = 1,
        test_y_channel: bool = True,
        color_space: str = 'yiq',
        crop_border: int = 0,
        version: str = 'original',
        pretrained_model_path: str = None,
    ) -> None:
        super(NIQE, self).__init__()
        self.channels = channels
        self.test_y_channel = test_y_channel
        self.color_space = color_space
        self.crop_border = crop_border
        if pretrained_model_path is not None:
            pretrained_model_path = pretrained_model_path
        elif version == 'original':
            pretrained_model_path = load_file_from_url(default_model_urls['url'])
        elif version == 'matlab':
            pretrained_model_path = load_file_from_url(
                default_model_urls['niqe_matlab']
            )

        # load model parameters
        params = scipy.io.loadmat(pretrained_model_path)
        mu_pris_param = np.ravel(params['mu_prisparam'])
        cov_pris_param = params['cov_prisparam']
        self.mu_pris_param = torch.from_numpy(mu_pris_param)
        self.cov_pris_param = torch.from_numpy(cov_pris_param)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Computation of NIQE metric.
        Input:
            x: An input tensor. Shape :math:`(N, C, H, W)`.
        Output:
            score (tensor): results of ilniqe metric, should be a positive real number. Shape :math:`(N, 1)`.
        """
        score = calculate_niqe(
            x,
            self.crop_border,
            self.test_y_channel,
            self.color_space,
            self.mu_pris_param,
            self.cov_pris_param,
        )
        return score

调用calculate_niqe函数得到结果。这里初始化会直接导入预训练的权重,即我们前面讲到的使用标准图像拟合的分布,保存在mu_pris和conv_pris中。

calculate_niqe 函数

实际计算的代码。

def calculate_niqe(
    img: torch.Tensor,
    crop_border: int = 0,
    test_y_channel: bool = True,
    color_space: str = 'yiq',
    mu_pris_param: torch.Tensor = None,
    cov_pris_param: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """Calculate NIQE (Natural Image Quality Evaluator) metric.
    Args:
        img (Tensor): Input image whose quality needs to be computed.
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the metric calculation.
        test_y_channel (Bool): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
        pretrained_model_path (str): The pretrained model path.
    Returns:
        Tensor: NIQE result.
    """

    # NIQE only support gray image
    if img.shape[1] == 3:
        img = to_y_channel(img, 255, color_space)
    elif img.shape[1] == 1:
        img = img * 255

    img = diff_round(img)
    img = img.to(torch.float64)

    mu_pris_param = mu_pris_param.to(img).repeat(img.size(0), 1)
    cov_pris_param = cov_pris_param.to(img).repeat(img.size(0), 1, 1)

    if crop_border != 0:
        img = img[..., crop_border:-crop_border, crop_border:-crop_border]

    niqe_result = niqe(img, mu_pris_param, cov_pris_param)

    return niqe_result

又封装了一层,实际在niqe中发生。

niqe 函数

实际计算的代码。

def niqe(
    img: torch.Tensor,
    mu_pris_param: torch.Tensor,
    cov_pris_param: torch.Tensor,
    block_size_h: int = 96,
    block_size_w: int = 96,
) -> torch.Tensor:
    assert img.ndim == 4, (
        'Input image must be a gray or Y (of YCbCr) image with shape (b, c, h, w).'
    )
    # crop image
    b, c, h, w = img.shape
    num_block_h = math.floor(h / block_size_h)
    num_block_w = math.floor(w / block_size_w)
    img = img[..., 0 : num_block_h * block_size_h, 0 : num_block_w * block_size_w]

    distparam = []  # dist param is actually the multiscale features
    for scale in (1, 2):  # perform on two scales (1, 2)
        img_normalized = normalize_img_with_gauss(img, padding='replicate')

        distparam.append(
            blockproc(
                img_normalized,
                [block_size_h // scale, block_size_w // scale],
                fun=compute_feature,
            )
        )

        if scale == 1:
            img = imresize(img / 255.0, scale=0.5, antialiasing=True)
            img = img * 255.0

    distparam = torch.cat(distparam, -1)

    # fit a MVG (multivariate Gaussian) model to distorted patch features
    mu_distparam = nanmean(distparam, dim=1)
    cov_distparam = nancov(distparam)

    # compute niqe quality, Eq. 10 in the paper
    invcov_param = torch.linalg.pinv((cov_pris_param + cov_distparam) / 2)
    diff = (mu_pris_param - mu_distparam).unsqueeze(1)
    quality = torch.bmm(torch.bmm(diff, invcov_param), diff.transpose(1, 2)).squeeze()

    quality = torch.sqrt(quality)
    return quality

分为以下几步:

  1. 图像裁剪:
b, c, h, w = img.shape
num_block_h = math.floor(h / block_size_h)
num_block_w = math.floor(w / block_size_w)
img = img[..., 0 : num_block_h * block_size_h, 0 : num_block_w * block_size_w]

将输入图像裁剪为能被 block_size 整除的尺寸,避免边缘无法形成完整图像块的问题。

  1. 多尺度特征提取:
distparam = []  # 存储多尺度特征
for scale in (1, 2):  # 在两个尺度(1x和0.5x)上提取特征
    # 步骤1:高斯归一化预处理
    img_normalized = normalize_img_with_gauss(img, padding='replicate')
    
    # 步骤2:分块计算特征
    distparam.append(
        blockproc(
            img_normalized,
            [block_size_h // scale, block_size_w // scale],  # 尺度缩放时块大小同步缩小
            fun=compute_feature,  # 计算每个块的统计特征
        )
    )
    
    # 步骤3:下采样到0.5x尺度(仅在第一个尺度后执行)
    if scale == 1:
        img = imresize(img / 255.0, scale=0.5, antialiasing=True)  # 下采样
        img = img * 255.0  # 恢复像素值范围

在两个尺度(原始尺度和 0.5 倍下采样尺度)上提取特征,模拟人类视觉系统的多尺度感知特性。normalize_img_with_gauss:对图像进行高斯滤波归一化(去除局部均值和对比度),这对应于NSS模型;blockproc:将图像分块,对每个块调用 compute_feature 提取统计特征(如广义高斯分布 GGD、非对称广义高斯分布 AGGD 的参数),对应论文中的分块特征提取部分。

  1. 特征拼接与 MVG 模型拟合:
# 拼接两个尺度的特征
distparam = torch.cat(distparam, -1)

# 拟合失真图像的多元高斯(MVG)模型
mu_distparam = nanmean(distparam, dim=1)  # 特征均值向量
cov_distparam = nancov(distparam)        # 特征协方差矩阵

将多尺度特征拼接为一个特征向量,然后计算失真图像特征的均值和协方差,拟合 MVG 模型(nanmean 和 nancov 用于处理可能的 NaN 值(避免异常值影响统计特性))。

  1. 计算 NIQE 质量分数:
# 计算NIQE分数(论文公式10)
invcov_param = torch.linalg.pinv((cov_pris_param + cov_distparam) / 2)  # 平均协方差的伪逆
diff = (mu_pris_param - mu_distparam).unsqueeze(1)  # 均值差向量(扩展维度用于矩阵乘法)
# 马氏距离计算:sqrt[(μ1-μ2)^T * Σ_avg^{-1} * (μ1-μ2)]
quality = torch.bmm(torch.bmm(diff, invcov_param), diff.transpose(1, 2)).squeeze()
quality = torch.sqrt(quality)
return quality

利用概要中提到的niqe公式计算马氏距离。

3、总结

代码实现核心的部分讲解完毕,NIQE作为一个传统的无参考IQA指标,使用比较广泛。


感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。

您可能感兴趣的与本文相关的镜像

PyTorch 2.6

PyTorch 2.6

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值