本文是对MUSIQ图像质量评价指标的代码解读,原文解读请看MUSIQ文章讲解。
本文的代码来源于IQA-Pytorch工程。
1、原文概要
该文章设计的指标在很多图像复原算法的测评中会看到,MUSIQ的特点在于其可以解决卷积神经网络(CNN)在图像质量评估(IQA)中因固定输入尺寸限制导致的图像质量失真问题。模型结构如下图所示:

MUSIQ 通过多尺度表示、空间嵌入和尺度嵌入处理可变尺寸输入,由三部分组成:
- 多尺度补丁嵌入:生成原生分辨率图像及其纵横比保持(ARP)缩放变体的补丁序列;就是最下面一行,多个不同尺寸的保持宽高比的图像序列,作者总共生成了k个,实际使用时k=3,即有3个不同大小相同宽高比的图像被打成序列输入进网络。
- 哈希基 2D 空间嵌入:编码补丁的 2D 空间位置;此模块是为了优化原有VIT的位置编码功能,由于我们进行了resize调整了绝对位置,如果还是用绝对位置编码应用到各个不同分辨率输入上显然是不合理的,此为一个hash-based的2D编码,可以解决这个问题。
- 尺度嵌入(SCE):区分不同尺度的补丁;此为了将我们resize的倍数作为信息也编码进去,因为这个也是一个重要的先验信息。
最终通过 Transformer 编码器聚合信息,输出质量评分。
2、代码结构
代码实现位于pyiqa/archs/musiq_arch.py中:

3 、核心代码模块
MUSIQ 类
这个类实现了整体的参数传入与函数调用。
@ARCH_REGISTRY.register()
class MUSIQ(nn.Module):
def __init__(
self,
patch_size=32,
num_class=1,
hidden_size=384,
mlp_dim=1152,
attention_dropout_rate=0.0,
dropout_rate=0,
num_heads=6,
num_layers=14,
num_scales=3,
spatial_pos_grid_size=10,
use_scale_emb=True,
use_sinusoid_pos_emb=False,
pretrained=True,
pretrained_model_path=None,
# data opts
longer_side_lengths=[224, 384],
max_seq_len_from_original_res=-1,
):
super(MUSIQ, self).__init__()
resnet_token_dim = 64
self.patch_size = patch_size
self.data_preprocess_opts = {
'patch_size': patch_size,
'patch_stride': patch_size,
'hse_grid_size': spatial_pos_grid_size,
'longer_side_lengths': longer_side_lengths,
'max_seq_len_from_original_res': max_seq_len_from_original_res,
}
# set num_class to 10 if pretrained model used AVA dataset
# if not specified pretrained dataset, use AVA for default
if pretrained_model_path is None and pretrained:
url_key = 'ava' if isinstance(pretrained, bool) else pretrained
num_class = 10 if url_key == 'ava' else num_class
pretrained_model_path = default_model_urls[url_key]
self.conv_root = StdConv(3, resnet_token_dim, 7, 2, bias=False)
self.gn_root = nn.GroupNorm(32, resnet_token_dim, eps=1e-6)
self.root_pool = nn.Sequential(
nn.ReLU(True),
ExactPadding2d(3, 2, mode='same'),
nn.MaxPool2d(3, 2),
)
token_patch_size = patch_size // 4
self.block1 = Bottleneck(resnet_token_dim, resnet_token_dim * 4)
self.embedding = nn.Linear(
resnet_token_dim * 4 * token_patch_size**2, hidden_size
)
self.transformer_encoder = TransformerEncoder(
hidden_size,
mlp_dim,
attention_dropout_rate,
dropout_rate,
num_heads,
num_layers,
num_scales,
spatial_pos_grid_size,
use_scale_emb,
use_sinusoid_pos_emb,
)
if num_class > 1:
self.head = nn.Sequential(
nn.Linear(hidden_size, num_class),
nn.Softmax(dim=-1),
)
else:
self.head = nn.Linear(hidden_size, num_class)
if pretrained_model_path is not None:
load_pretrained_network(self, pretrained_model_path, True)
def forward(self, x, return_mos=True, return_dist=False):
# normalize inputs to [-1, 1] as the official code
if not self.training:
x = (x - 0.5) * 2
x = get_multiscale_patches(x, **self.data_preprocess_opts)
assert len(x.shape) in [3, 4]
if len(x.shape) == 4:
b, num_crops, seq_len, dim = x.shape
x = x.reshape(b * num_crops, seq_len, dim)
else:
b, seq_len, dim = x.shape
num_crops = 1
inputs_spatial_positions = x[:, :, -3]
inputs_scale_positions = x[:, :, -2]
inputs_masks = x[:, :, -1].bool()
x = x[:, :, :-3]
x = x.reshape(-1, 3, self.patch_size, self.patch_size)
x = self.conv_root(x)
x = self.gn_root(x)
x = self.root_pool(x)
x = self.block1(x)
# to match tensorflow channel order
x = x.permute(0, 2, 3, 1)
x = x.reshape(b, seq_len, -1)
x = self.embedding(x)
x = self.transformer_encoder(
x, inputs_spatial_positions, inputs_scale_positions, inputs_masks
)
q = self.head(x[:, 0])
q = q.reshape(b, num_crops, -1)
q = q.mean(dim=1) # for multiple crops evaluation
mos = dist_to_mos(q)
return_list = []
if return_mos:
return_list.append(mos)
if return_dist:
return_list.append(q)
if len(return_list) > 1:
return return_list
else:
return return_list[0]
整体推理流程如下:多尺度的补丁获取,这在get_multiscale_patches函数中实现,get_multiscale_patches中还会生成所需要的空间的hash编码以及尺度的编码,多个patch利用一些网络结构将其变换为embedding,最后将这些crop的patch统一为token,送入transformer的编码器操作,取出第1个token,作为打分token输出。
get_multiscale_patches 函数
实际计算的代码。
def get_multiscale_patches(
image,
patch_size=32,
patch_stride=32,
hse_grid_size=10,
longer_side_lengths=[224, 384],
max_seq_len_from_original_res=None,
):
# Sorting the list to ensure a deterministic encoding of the scale position.
longer_side_lengths = sorted(longer_side_lengths)
if len(image.shape) == 3:
image = image.unsqueeze(0)
n_crops, c, h, w = image.shape
outputs = []
for scale_id, longer_size in enumerate(longer_side_lengths):
resized_image, rh, rw = resize_preserve_aspect_ratio(image, h, w, longer_size)
max_seq_len = int(np.ceil(longer_size / patch_stride) ** 2)
out = _extract_patches_and_positions_from_image(
resized_image,
patch_size,
patch_stride,
hse_grid_size,
n_crops,
rh,
rw,
c,
scale_id,
max_seq_len,
)
outputs.append(out)
if max_seq_len_from_original_res is not None:
out = _extract_patches_and_positions_from_image(
image,
patch_size,
patch_stride,
hse_grid_size,
n_crops,
h,
w,
c,
len(longer_side_lengths),
max_seq_len_from_original_res,
)
outputs.append(out)
outputs = torch.cat(outputs, dim=-1)
return outputs.transpose(1, 2)
流程如下:
- 首先根据scale进行等比例的resize,在resize_preserve_aspect_ratio中实现。
- 接着根据max_seq_len来从图像获取patch以及HSE, SCE, input mask等内容,输出的tensor大小是(n_crops, num_patches, patch_size * patch_size * c + 3)。
resize_preserve_aspect_ratio 函数
长边resize方法。
def resize_preserve_aspect_ratio(image, h, w, longer_side_length):
# Computes the height and width after aspect-ratio-preserving resizing.
ratio = longer_side_length / max(h, w)
rh = round(h * ratio)
rw = round(w * ratio)
resized = F.interpolate(image, (rh, rw), mode='bicubic', align_corners=False)
return resized, rh, rw
_extract_patches_and_positions_from_image 函数
核心函数模块,将图像打成patch,获取对应的hash编码、尺度编码以及mask遮蔽不需要的地方。
def _extract_patches_and_positions_from_image(
image,
patch_size,
patch_stride,
hse_grid_size,
n_crops,
h,
w,
c,
scale_id,
max_seq_len,
):
n_crops, c, h, w = image.shape
p = extract_image_patches(image, patch_size, patch_stride)
assert p.shape[1] == c * patch_size**2
count_h = _ceil_divide_int(h, patch_stride)
count_w = _ceil_divide_int(w, patch_stride)
# Shape (1, num_patches)
spatial_p = get_hashed_spatial_pos_emb_index(hse_grid_size, count_h, count_w)
# Shape (n_crops, 1, num_patches)
spatial_p = spatial_p.unsqueeze(1).repeat(n_crops, 1, 1)
scale_p = torch.ones_like(spatial_p) * scale_id
mask_p = torch.ones_like(spatial_p)
# Concatenating is a hacky way to pass both patches, positions and input
# mask to the model.
# Shape (n_crops, c * patch_size * patch_size + 3, num_patches)
out = torch.cat([p, spatial_p.to(p), scale_p.to(p), mask_p.to(p)], dim=1)
if max_seq_len >= 0:
out = _pad_or_cut_to_max_seq_len(out, max_seq_len)
return out
他的核心函数解析如下:
extract_image_patches 函数
将图像打成patch,使用的是unfold函数,这个函数相当于卷积拿数据的过程。
def extract_image_patches(x, kernel, stride=1, dilation=1):
"""
Ref: https://stackoverflow.com/a/65886666
"""
# Do TF 'SAME' Padding
b, c, h, w = x.shape
h2 = math.ceil(h / stride)
w2 = math.ceil(w / stride)
pad_row = (h2 - 1) * stride + (kernel - 1) * dilation + 1 - h
pad_col = (w2 - 1) * stride + (kernel - 1) * dilation + 1 - w
x = F.pad(
x, (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2)
)
# Extract patches
patches = F.unfold(x, kernel, dilation, stride=stride)
return patches
get_hashed_spatial_pos_emb_index 函数
获取图像的hash编码,首先生成一个grid_size的序列,然后使用近邻将其插值到实际的宽度和高度上,这样使用的值不会超出grid_size,最后将他们根据宽高组合起来就得到了最终的index。
def get_hashed_spatial_pos_emb_index(grid_size, count_h, count_w):
pos_emb_grid = torch.arange(grid_size).float()
pos_emb_hash_w = pos_emb_grid.reshape(1, 1, grid_size)
pos_emb_hash_w = F.interpolate(pos_emb_hash_w, (count_w), mode='nearest')
pos_emb_hash_w = pos_emb_hash_w.repeat(1, count_h, 1)
pos_emb_hash_h = pos_emb_grid.reshape(1, 1, grid_size)
pos_emb_hash_h = F.interpolate(pos_emb_hash_h, (count_h), mode='nearest')
pos_emb_hash_h = pos_emb_hash_h.transpose(1, 2)
pos_emb_hash_h = pos_emb_hash_h.repeat(1, 1, count_w)
pos_emb_hash = pos_emb_hash_h * grid_size + pos_emb_hash_w
pos_emb_hash = pos_emb_hash.reshape(1, -1)
return pos_emb_hash
输入讲解完毕,现在回到模型的前向的网络结构部分。
class TransformerEncoder(nn.Module):
def __init__(
self,
input_dim,
mlp_dim=1152,
attention_dropout_rate=0.0,
dropout_rate=0,
num_heads=6,
num_layers=14,
num_scales=3,
spatial_pos_grid_size=10,
use_scale_emb=True,
use_sinusoid_pos_emb=False,
):
super().__init__()
self.use_scale_emb = use_scale_emb
self.posembed_input = AddHashSpatialPositionEmbs(
spatial_pos_grid_size, input_dim
)
self.scaleembed_input = AddScaleEmbs(num_scales, input_dim)
self.cls = nn.parameter.Parameter(torch.zeros(1, 1, input_dim))
self.dropout = nn.Dropout(dropout_rate)
self.encoder_norm = nn.LayerNorm(input_dim, eps=1e-6)
self.transformer = nn.ModuleDict()
for i in range(num_layers):
self.transformer[f'encoderblock_{i}'] = TransformerBlock(
input_dim, mlp_dim, num_heads, dropout_rate, attention_dropout_rate
)
def forward(
self, x, inputs_spatial_positions, inputs_scale_positions, inputs_masks
):
n, _, c = x.shape
x = self.posembed_input(x, inputs_spatial_positions)
if self.use_scale_emb:
x = self.scaleembed_input(x, inputs_scale_positions)
cls_token = self.cls.repeat(n, 1, 1)
x = torch.cat([cls_token, x], dim=1)
cls_mask = torch.ones((n, 1)).to(inputs_masks)
inputs_mask = torch.cat([cls_mask, inputs_masks], dim=1)
x = self.dropout(x)
for k, m in self.transformer.items():
x = m(x, inputs_mask)
x = self.encoder_norm(x)
return x
可以看到首先会相加hash编码和尺度编码,然后将cls_token加入到输入中,后续输入到transformer的经典结构中进行自注意力操作后输出。
与embedding相加的部分如下:
class AddHashSpatialPositionEmbs(nn.Module):
"""Adds learnable hash-based spatial embeddings to the inputs."""
def __init__(self, spatial_pos_grid_size, dim):
super().__init__()
self.position_emb = nn.parameter.Parameter(
torch.randn(1, spatial_pos_grid_size * spatial_pos_grid_size, dim)
)
nn.init.normal_(self.position_emb, std=0.02)
def forward(self, inputs, inputs_positions):
return inputs + self.position_emb.squeeze(0)[inputs_positions.long()]
class AddScaleEmbs(nn.Module):
"""Adds learnable scale embeddings to the inputs."""
def __init__(self, num_scales, dim):
super().__init__()
self.scale_emb = nn.parameter.Parameter(torch.randn(num_scales, dim))
nn.init.normal_(self.scale_emb, std=0.02)
def forward(self, inputs, inputs_scale_positions):
return inputs + self.scale_emb[inputs_scale_positions.long()]
Transformer模块是非常经典的MHA+FFA。
class TransformerBlock(nn.Module):
def __init__(
self,
dim,
mlp_dim,
num_heads,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim, eps=1e-6)
self.attention = MultiHeadAttention(
dim, num_heads, bias=True, attn_drop=attn_drop
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim, eps=1e-6)
self.mlp = Mlp(
in_features=dim, hidden_features=mlp_dim, act_layer=act_layer, drop=drop
)
def forward(self, x, inputs_masks):
y = self.norm1(x)
y = self.attention(y, inputs_masks)
x = x + self.drop_path(y)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
至此,模型前向完毕,模型后续还有一些后处理。
x = self.transformer_encoder(
x, inputs_spatial_positions, inputs_scale_positions, inputs_masks
)
q = self.head(x[:, 0])
q = q.reshape(b, num_crops, -1)
q = q.mean(dim=1) # for multiple crops evaluation
mos = dist_to_mos(q)
这里因为vit的模型设计是第1个token即cls_token代表输出,因此我们拿出第一个token作为q,q再进行reshape得到num_crops(前面多尺度送入的图像数目),这样求平均后,可以得到一个分布。
经过一下操作可以得到最终的评分。
def dist_to_mos(dist_score: torch.Tensor) -> torch.Tensor:
"""
Convert distribution prediction to MOS score.
For datasets with detailed score labels, such as AVA.
Args:
dist_score (torch.Tensor): (*, C), C is the class number.
Returns:
torch.Tensor: (*, 1) MOS score.
"""
num_classes = dist_score.shape[-1]
mos_score = dist_score * torch.arange(1, num_classes + 1).to(dist_score)
mos_score = mos_score.sum(dim=-1, keepdim=True)
return mos_score
此操作在讲解中有提到计算EMD损失时会使用到,一般来说num_classes可能是10(AVA数据集),这对应着打分的等级数目,也就是说模型可以得到分数的分布,这可以更好的利用数据集的标签。
3、总结
代码实现核心的部分讲解完毕,MUSIQ作为一个深度的无参考IQA指标,可以对数据进行质量或美学的评估,模型结构的设计相对合理,可以针对各种分辨率的输入给出合理的判断结果,在众多深度学习算法的客观评估中都有用到。
感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。

被折叠的 条评论
为什么被折叠?



