[IQA Technical Series] MUSIQ Code Walkthrough

This post is a code walkthrough of the MUSIQ image quality assessment metric; for a discussion of the paper itself, see the MUSIQ paper walkthrough.
The code discussed here comes from the IQA-PyTorch project.

1. Paper Overview

The metric proposed in this paper appears in the evaluation of many image restoration algorithms. MUSIQ's distinguishing feature is that it removes the image quality distortion caused by the fixed-input-size constraint of convolutional neural networks (CNNs) in image quality assessment (IQA). The model structure is shown below:
[Figure: MUSIQ model architecture]

MUSIQ handles variable-size inputs through multi-scale representation, spatial embeddings, and scale embeddings, and consists of three parts:

  1. Multi-scale patch embedding: generates patch sequences from the native-resolution image and its aspect-ratio-preserving (ARP) resized variants. This is the bottom row of the figure: several images of different sizes but the same aspect ratio. The authors generate K of them in total; in practice K = 3, i.e. three images of different sizes and identical aspect ratio are split into patch sequences and fed into the network.
  2. Hash-based 2D spatial embedding (HSE): encodes the 2D spatial position of each patch. This module replaces the positional encoding of the original ViT: since the ARP resizing changes absolute positions, applying a single absolute positional embedding to inputs of different resolutions would clearly be unreasonable, and the hash-based 2D encoding solves this problem.
  3. Scale embedding (SCE): distinguishes patches from different scales. It encodes the resize factor as additional information, since the scale is itself an important prior.

The information is finally aggregated by a Transformer encoder, which outputs the quality score; a toy sketch of this token pipeline follows.
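
Before diving into the code, here is a runnable toy sketch of the pipeline described above. This is my own illustration, not code from the repository; all shapes, table sizes and indices are made up for clarity.

import torch

hidden = 384                                   # token dimension used by MUSIQ
tokens = torch.randn(1, 193, hidden)           # patch embeddings gathered from all scales
hse_table = torch.randn(100, hidden)           # 10 x 10 hashed spatial embedding table
sce_table = torch.randn(3, hidden)             # one learnable embedding per scale
spatial_idx = torch.randint(0, 100, (1, 193))  # hashed 2D position index of each patch
scale_idx = torch.randint(0, 3, (1, 193))      # which scale each patch came from

x = tokens + hse_table[spatial_idx] + sce_table[scale_idx]  # add HSE and SCE to the tokens
cls = torch.zeros(1, 1, hidden)
x = torch.cat([cls, x], dim=1)                 # prepend the CLS token -> (1, 194, hidden)
# x then goes through the Transformer encoder; the CLS output is mapped to the quality score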

2. Code Structure

The implementation lives in pyiqa/archs/musiq_arch.py.
[Figure: contents of musiq_arch.py]
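
For context, IQA-PyTorch exposes MUSIQ as a ready-made metric. A minimal usage sketch is shown below; the metric name 'musiq' and the image path 'test.jpg' are illustrative and may differ across pyiqa versions.

import torch
import pyiqa

device = 'cuda' if torch.cuda.is_available() else 'cpu'
metric = pyiqa.create_metric('musiq', device=device)  # no-reference quality metric
score = metric('test.jpg')   # accepts a file path or a (N, 3, H, W) tensor in [0, 1]
print(score)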

3. Core Code Modules

MUSIQ

This class wires together the configuration parameters and the calls to the sub-modules.

@ARCH_REGISTRY.register()
class MUSIQ(nn.Module):
    def __init__(
        self,
        patch_size=32,
        num_class=1,
        hidden_size=384,
        mlp_dim=1152,
        attention_dropout_rate=0.0,
        dropout_rate=0,
        num_heads=6,
        num_layers=14,
        num_scales=3,
        spatial_pos_grid_size=10,
        use_scale_emb=True,
        use_sinusoid_pos_emb=False,
        pretrained=True,
        pretrained_model_path=None,
        # data opts
        longer_side_lengths=[224, 384],
        max_seq_len_from_original_res=-1,
    ):
        super(MUSIQ, self).__init__()

        resnet_token_dim = 64
        self.patch_size = patch_size

        self.data_preprocess_opts = {
            'patch_size': patch_size,
            'patch_stride': patch_size,
            'hse_grid_size': spatial_pos_grid_size,
            'longer_side_lengths': longer_side_lengths,
            'max_seq_len_from_original_res': max_seq_len_from_original_res,
        }

        # set num_class to 10 if pretrained model used AVA dataset
        # if not specified pretrained dataset, use AVA for default
        if pretrained_model_path is None and pretrained:
            url_key = 'ava' if isinstance(pretrained, bool) else pretrained
            num_class = 10 if url_key == 'ava' else num_class
            pretrained_model_path = default_model_urls[url_key]

        self.conv_root = StdConv(3, resnet_token_dim, 7, 2, bias=False)
        self.gn_root = nn.GroupNorm(32, resnet_token_dim, eps=1e-6)
        self.root_pool = nn.Sequential(
            nn.ReLU(True),
            ExactPadding2d(3, 2, mode='same'),
            nn.MaxPool2d(3, 2),
        )

        token_patch_size = patch_size // 4
        self.block1 = Bottleneck(resnet_token_dim, resnet_token_dim * 4)

        self.embedding = nn.Linear(
            resnet_token_dim * 4 * token_patch_size**2, hidden_size
        )
        self.transformer_encoder = TransformerEncoder(
            hidden_size,
            mlp_dim,
            attention_dropout_rate,
            dropout_rate,
            num_heads,
            num_layers,
            num_scales,
            spatial_pos_grid_size,
            use_scale_emb,
            use_sinusoid_pos_emb,
        )

        if num_class > 1:
            self.head = nn.Sequential(
                nn.Linear(hidden_size, num_class),
                nn.Softmax(dim=-1),
            )
        else:
            self.head = nn.Linear(hidden_size, num_class)

        if pretrained_model_path is not None:
            load_pretrained_network(self, pretrained_model_path, True)

    def forward(self, x, return_mos=True, return_dist=False):
        # normalize inputs to [-1, 1] as the official code
        if not self.training:
            x = (x - 0.5) * 2
            x = get_multiscale_patches(x, **self.data_preprocess_opts)

        assert len(x.shape) in [3, 4]
        if len(x.shape) == 4:
            b, num_crops, seq_len, dim = x.shape
            x = x.reshape(b * num_crops, seq_len, dim)
        else:
            b, seq_len, dim = x.shape
            num_crops = 1

        inputs_spatial_positions = x[:, :, -3]
        inputs_scale_positions = x[:, :, -2]
        inputs_masks = x[:, :, -1].bool()
        x = x[:, :, :-3]

        x = x.reshape(-1, 3, self.patch_size, self.patch_size)
        x = self.conv_root(x)
        x = self.gn_root(x)
        x = self.root_pool(x)
        x = self.block1(x)
        # to match tensorflow channel order
        x = x.permute(0, 2, 3, 1)
        x = x.reshape(b, seq_len, -1)
        x = self.embedding(x)
        x = self.transformer_encoder(
            x, inputs_spatial_positions, inputs_scale_positions, inputs_masks
        )
        q = self.head(x[:, 0])

        q = q.reshape(b, num_crops, -1)
        q = q.mean(dim=1)  # for multiple crops evaluation
        mos = dist_to_mos(q)

        return_list = []
        if return_mos:
            return_list.append(mos)
        if return_dist:
            return_list.append(q)

        if len(return_list) > 1:
            return return_list
        else:
            return return_list[0]

The overall inference flow is: the multi-scale patches are obtained in get_multiscale_patches, which also generates the required hashed spatial positions and scale positions; the patches are then turned into embeddings by a small network stem (conv_root, root_pool and block1) followed by a linear projection; finally all patch tokens are fed into the Transformer encoder, and the first token (the CLS token) is taken out as the scoring token. A minimal smoke test of this flow is sketched below.
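
The snippet below is my own sketch, not part of the repository: it instantiates the architecture without pretrained weights and pushes a dummy image through it in eval mode, so the preprocessing inside forward() is exercised. The score itself is meaningless here; only the shapes matter.

import torch
from pyiqa.archs.musiq_arch import MUSIQ

model = MUSIQ(pretrained=False).eval()   # keeps num_class=1 and skips the weight download
x = torch.rand(1, 3, 500, 600)           # dummy RGB image in [0, 1]
with torch.no_grad():
    mos = model(x)                       # multi-scale patching happens inside forward() in eval mode
print(mos.shape)                         # expected: torch.Size([1, 1])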

get_multiscale_patches function

This function performs the actual multi-scale preprocessing.

def get_multiscale_patches(
    image,
    patch_size=32,
    patch_stride=32,
    hse_grid_size=10,
    longer_side_lengths=[224, 384],
    max_seq_len_from_original_res=None,
):
    # Sorting the list to ensure a deterministic encoding of the scale position.
    longer_side_lengths = sorted(longer_side_lengths)

    if len(image.shape) == 3:
        image = image.unsqueeze(0)

    n_crops, c, h, w = image.shape

    outputs = []
    for scale_id, longer_size in enumerate(longer_side_lengths):
        resized_image, rh, rw = resize_preserve_aspect_ratio(image, h, w, longer_size)

        max_seq_len = int(np.ceil(longer_size / patch_stride) ** 2)
        out = _extract_patches_and_positions_from_image(
            resized_image,
            patch_size,
            patch_stride,
            hse_grid_size,
            n_crops,
            rh,
            rw,
            c,
            scale_id,
            max_seq_len,
        )
        outputs.append(out)

    if max_seq_len_from_original_res is not None:
        out = _extract_patches_and_positions_from_image(
            image,
            patch_size,
            patch_stride,
            hse_grid_size,
            n_crops,
            h,
            w,
            c,
            len(longer_side_lengths),
            max_seq_len_from_original_res,
        )
        outputs.append(out)

    outputs = torch.cat(outputs, dim=-1)
    return outputs.transpose(1, 2)

The flow is as follows:

  1. First, each scale is produced by an aspect-ratio-preserving resize, implemented in resize_preserve_aspect_ratio.
  2. Then patches are extracted from the image together with the HSE indices, SCE indices and input mask, padded or cut according to max_seq_len; the output tensor has shape (n_crops, num_patches, patch_size * patch_size * c + 3). A worked shape example follows.
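
Here is the worked shape example, computed by hand from the code above with the default options of the MUSIQ class; it assumes the helper is importable from pyiqa/archs/musiq_arch.py as stated earlier.

import torch
from pyiqa.archs.musiq_arch import get_multiscale_patches

img = torch.rand(1, 3, 500, 600)
out = get_multiscale_patches(
    img,
    patch_size=32,
    patch_stride=32,
    hse_grid_size=10,
    longer_side_lengths=[224, 384],
    max_seq_len_from_original_res=-1,
)
# scale 0: ARP resize to 187x224 -> ceil(187/32) * ceil(224/32) = 6 * 7 = 42 patches, padded to ceil(224/32)**2 = 49
# scale 1: ARP resize to 320x384 -> 10 * 12 = 120 patches, padded to ceil(384/32)**2 = 144
# original resolution: ceil(500/32) * ceil(600/32) = 16 * 19 = 304 patches (max_seq_len = -1, so no padding)
# sequence length: 49 + 144 + 304 = 497; last dim: 3 * 32 * 32 + 3 = 3075
print(out.shape)  # expected: torch.Size([1, 497, 3075])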

resize_preserve_aspect_ratio function

Resizes the image so that its longer side matches the target length while preserving the aspect ratio (for example, a 500x600 image with longer_side_length=224 becomes 187x224).

def resize_preserve_aspect_ratio(image, h, w, longer_side_length):
    # Computes the height and width after aspect-ratio-preserving resizing.
    ratio = longer_side_length / max(h, w)
    rh = round(h * ratio)
    rw = round(w * ratio)

    resized = F.interpolate(image, (rh, rw), mode='bicubic', align_corners=False)
    return resized, rh, rw

_extract_patches_and_positions_from_image function

The core helper: it splits the image into patches and produces the corresponding hash-based spatial indices, scale indices, and a mask that hides the padded positions.

def _extract_patches_and_positions_from_image(
    image,
    patch_size,
    patch_stride,
    hse_grid_size,
    n_crops,
    h,
    w,
    c,
    scale_id,
    max_seq_len,
):
    n_crops, c, h, w = image.shape
    p = extract_image_patches(image, patch_size, patch_stride)
    assert p.shape[1] == c * patch_size**2

    count_h = _ceil_divide_int(h, patch_stride)
    count_w = _ceil_divide_int(w, patch_stride)

    # Shape (1, num_patches)
    spatial_p = get_hashed_spatial_pos_emb_index(hse_grid_size, count_h, count_w)
    # Shape (n_crops, 1, num_patches)
    spatial_p = spatial_p.unsqueeze(1).repeat(n_crops, 1, 1)
    scale_p = torch.ones_like(spatial_p) * scale_id
    mask_p = torch.ones_like(spatial_p)

    # Concatenating is a hacky way to pass both patches, positions and input
    # mask to the model.
    # Shape (n_crops, c * patch_size * patch_size + 3, num_patches)
    out = torch.cat([p, spatial_p.to(p), scale_p.to(p), mask_p.to(p)], dim=1)
    if max_seq_len >= 0:
        out = _pad_or_cut_to_max_seq_len(out, max_seq_len)
    return out

Its key sub-functions are analyzed below:

extract_image_patches function

Splits the image into patches using F.unfold, which gathers data in the same sliding-window fashion as a convolution.

def extract_image_patches(x, kernel, stride=1, dilation=1):
    """
    Ref: https://stackoverflow.com/a/65886666
    """
    # Do TF 'SAME' Padding
    b, c, h, w = x.shape
    h2 = math.ceil(h / stride)
    w2 = math.ceil(w / stride)
    pad_row = (h2 - 1) * stride + (kernel - 1) * dilation + 1 - h
    pad_col = (w2 - 1) * stride + (kernel - 1) * dilation + 1 - w
    x = F.pad(
        x, (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2)
    )

    # Extract patches
    patches = F.unfold(x, kernel, dilation, stride=stride)
    return patches
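
A quick check of this behaviour (my own example, assuming the same import path as before): with kernel = stride = 32 and TF-style 'SAME' padding, a 187x224 input yields ceil(187/32) * ceil(224/32) = 6 * 7 = 42 patches.

import torch
from pyiqa.archs.musiq_arch import extract_image_patches

x = torch.rand(1, 3, 187, 224)
p = extract_image_patches(x, kernel=32, stride=32)
print(p.shape)  # expected: torch.Size([1, 3072, 42]), i.e. (n, c * 32 * 32, num_patches)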

get_hashed_spatial_pos_emb_index function

Computes the hashed spatial indices for the image: it first builds a sequence of length grid_size, then uses nearest-neighbor interpolation to stretch or shrink it to the actual patch-grid width and height, so the values never exceed grid_size; finally the row and column indices are combined (row * grid_size + column) into the final index.

def get_hashed_spatial_pos_emb_index(grid_size, count_h, count_w):

    pos_emb_grid = torch.arange(grid_size).float()

    pos_emb_hash_w = pos_emb_grid.reshape(1, 1, grid_size)
    pos_emb_hash_w = F.interpolate(pos_emb_hash_w, (count_w), mode='nearest')
    pos_emb_hash_w = pos_emb_hash_w.repeat(1, count_h, 1)

    pos_emb_hash_h = pos_emb_grid.reshape(1, 1, grid_size)
    pos_emb_hash_h = F.interpolate(pos_emb_hash_h, (count_h), mode='nearest')
    pos_emb_hash_h = pos_emb_hash_h.transpose(1, 2)
    pos_emb_hash_h = pos_emb_hash_h.repeat(1, 1, count_w)

    pos_emb_hash = pos_emb_hash_h * grid_size + pos_emb_hash_w

    pos_emb_hash = pos_emb_hash.reshape(1, -1)
    return pos_emb_hash
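
An illustrative call (again my own example, assuming the same import path): for grid_size=10 and a 6x7 patch grid, each patch is assigned an index row * 10 + column inside the fixed 10x10 embedding table.

from pyiqa.archs.musiq_arch import get_hashed_spatial_pos_emb_index

idx = get_hashed_spatial_pos_emb_index(10, count_h=6, count_w=7)
print(idx.shape)             # expected: torch.Size([1, 42])
print(idx.min(), idx.max())  # all indices fall inside [0, 99], i.e. within the 10x10 table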

That completes the input preparation; we now return to the network structure used in the model's forward pass.

class TransformerEncoder(nn.Module):
    def __init__(
        self,
        input_dim,
        mlp_dim=1152,
        attention_dropout_rate=0.0,
        dropout_rate=0,
        num_heads=6,
        num_layers=14,
        num_scales=3,
        spatial_pos_grid_size=10,
        use_scale_emb=True,
        use_sinusoid_pos_emb=False,
    ):
        super().__init__()
        self.use_scale_emb = use_scale_emb
        self.posembed_input = AddHashSpatialPositionEmbs(
            spatial_pos_grid_size, input_dim
        )
        self.scaleembed_input = AddScaleEmbs(num_scales, input_dim)

        self.cls = nn.parameter.Parameter(torch.zeros(1, 1, input_dim))
        self.dropout = nn.Dropout(dropout_rate)
        self.encoder_norm = nn.LayerNorm(input_dim, eps=1e-6)

        self.transformer = nn.ModuleDict()
        for i in range(num_layers):
            self.transformer[f'encoderblock_{i}'] = TransformerBlock(
                input_dim, mlp_dim, num_heads, dropout_rate, attention_dropout_rate
            )

    def forward(
        self, x, inputs_spatial_positions, inputs_scale_positions, inputs_masks
    ):
        n, _, c = x.shape

        x = self.posembed_input(x, inputs_spatial_positions)
        if self.use_scale_emb:
            x = self.scaleembed_input(x, inputs_scale_positions)

        cls_token = self.cls.repeat(n, 1, 1)
        x = torch.cat([cls_token, x], dim=1)

        cls_mask = torch.ones((n, 1)).to(inputs_masks)
        inputs_mask = torch.cat([cls_mask, inputs_masks], dim=1)
        x = self.dropout(x)

        for k, m in self.transformer.items():
            x = m(x, inputs_mask)
        x = self.encoder_norm(x)

        return x

As shown above, the encoder first adds the hash-based spatial embedding and the scale embedding to the input, then prepends the cls_token, and feeds the sequence through a stack of classic Transformer blocks performing self-attention.
The modules that add these embeddings are as follows:

class AddHashSpatialPositionEmbs(nn.Module):
    """Adds learnable hash-based spatial embeddings to the inputs."""

    def __init__(self, spatial_pos_grid_size, dim):
        super().__init__()
        self.position_emb = nn.parameter.Parameter(
            torch.randn(1, spatial_pos_grid_size * spatial_pos_grid_size, dim)
        )
        nn.init.normal_(self.position_emb, std=0.02)

    def forward(self, inputs, inputs_positions):
        return inputs + self.position_emb.squeeze(0)[inputs_positions.long()]


class AddScaleEmbs(nn.Module):
    """Adds learnable scale embeddings to the inputs."""

    def __init__(self, num_scales, dim):
        super().__init__()
        self.scale_emb = nn.parameter.Parameter(torch.randn(num_scales, dim))
        nn.init.normal_(self.scale_emb, std=0.02)

    def forward(self, inputs, inputs_scale_positions):
        return inputs + self.scale_emb[inputs_scale_positions.long()]

The Transformer block is the classic MHA + FFN (MLP) structure.

class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        mlp_dim,
        num_heads,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim, eps=1e-6)
        self.attention = MultiHeadAttention(
            dim, num_heads, bias=True, attn_drop=attn_drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim, eps=1e-6)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_dim, act_layer=act_layer, drop=drop
        )

    def forward(self, x, inputs_masks):
        y = self.norm1(x)
        y = self.attention(y, inputs_masks)
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

This completes the forward pass of the backbone; some post-processing follows.


        x = self.transformer_encoder(
            x, inputs_spatial_positions, inputs_scale_positions, inputs_masks
        )
        q = self.head(x[:, 0])

        q = q.reshape(b, num_crops, -1)
        q = q.mean(dim=1)  # for multiple crops evaluation
        mos = dist_to_mos(q)

Because the ViT design uses the first token, i.e. the cls_token, to represent the output, the first token is taken and passed through the head to obtain q. q is then reshaped to (b, num_crops, -1), where num_crops is the number of crops when several crops of the same image are evaluated together (the multi-scale patches were already concatenated into a single sequence), and averaged over the crops, which yields a score distribution.
The following function turns this distribution into the final score.

def dist_to_mos(dist_score: torch.Tensor) -> torch.Tensor:
    """
    Convert distribution prediction to MOS score.
    For datasets with detailed score labels, such as AVA.

    Args:
        dist_score (torch.Tensor): (*, C), C is the class number.

    Returns:
        torch.Tensor: (*, 1) MOS score.
    """
    num_classes = dist_score.shape[-1]
    mos_score = dist_score * torch.arange(1, num_classes + 1).to(dist_score)
    mos_score = mos_score.sum(dim=-1, keepdim=True)
    return mos_score

As mentioned in the paper walkthrough, this is used when computing the EMD loss. num_classes is typically 10 (the AVA dataset), corresponding to the number of score levels: the model thus predicts a distribution over scores rather than a single value, which makes better use of the dataset's labels. A tiny numeric example follows.
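
The example below is mine and assumes the dist_to_mos function shown above is in scope: a 10-way distribution peaked on score bin 8 maps to a MOS of (almost exactly) 8.

import torch

dist = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.8, 0.1, 0.0]])
mos = dist_to_mos(dist)  # sum_i i * p_i with i = 1..10 -> 0.1*7 + 0.8*8 + 0.1*9 = 8.0
print(mos)               # a (1, 1) tensor with a value of ~8.0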

4. Summary

That covers the core parts of the implementation. MUSIQ is a deep no-reference IQA metric that can assess image quality or aesthetics. Its architecture is sensibly designed to produce reasonable judgments for inputs of arbitrary resolution, and it is used in the objective evaluation of many deep-learning algorithms.


Thanks for reading. Feel free to leave a comment or send a private message to discuss.
If this post helped you, a follow for the author would be much appreciated.
