Preface
Following the plan, I reproduced MaskFormer: set up the environment the model requires and got the training script running on the lab server.
I then studied the model-building part of the source code, reading it alongside the original paper to understand the overall architecture.
1. Model Overview
The whole codebase is built on the detectron2 framework, so the source is full of registration decorators and from_config() functions. Neither affects the logic of the code, so there is no need to dwell on them when reading the source: treat every from_config() as simply reading the relevant values from the configuration file (the concrete values can be found in config/xx/xx.yaml), and the registration decorator exists only so that detectron2 can discover the class, so it can be ignored while reading.
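For readers new to detectron2, here is a minimal sketch of the register / from_config pattern; the registry name, class, and config key below are made up for illustration and are not from the MaskFormer repo:

from detectron2.config import configurable
from detectron2.utils.registry import Registry

# A registry is just a name -> class lookup table; detectron2 builds modules by name.
MY_REGISTRY = Registry("MY_MODULES")  # hypothetical registry, for illustration only

@MY_REGISTRY.register()
class MyHead:
    @configurable
    def __init__(self, *, num_classes: int):
        self.num_classes = num_classes

    @classmethod
    def from_config(cls, cfg):
        # from_config() only translates cfg entries into constructor arguments
        return {"num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES}

# detectron2 later does roughly:
#   MY_REGISTRY.get("MyHead")(cfg)   # calls from_config(cfg), then __init__(**kwargs)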
There are four core classes in total:
- mask_former/mask_former_model.py中的MaskFormer
- mask_former/modeling/heads/mask_former_head.py中的MaskFormerHead
- mask_former/modeling/heads/pixel_decoder.py中的TransformerEncoderPixelDecoder
- mask_former/modeling/transformer/transformer_predictor.py中的TransformerPredictor
@META_ARCH_REGISTRY.register()
class MaskFormer(nn.Module):
"""
Main class for mask classification semantic segmentation architectures.
"""
@configurable
def __init__(
self,
*,
backbone: Backbone,
sem_seg_head: nn.Module,
criterion: nn.Module,
num_queries: int,
panoptic_on: bool,
object_mask_threshold: float,
overlap_threshold: float,
metadata,
size_divisibility: int,
sem_seg_postprocess_before_inference: bool,
pixel_mean: Tuple[float],
pixel_std: Tuple[float],
):
The MaskFormer class is the entry point of the model; the overall structure lives in this class. Its constructor shows that the model is composed of a backbone, a sem_seg_head, and a criterion; by debugging you can print the modules these three arguments correspond to. The backbone corresponds to the backbone in the figure below, sem_seg_head contains everything except the backbone, and criterion is the loss-function class, corresponding to the two Loss blocks in the figure.
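Putting the three together, the training-time data flow of MaskFormer is roughly the following (a simplified sketch, not the verbatim forward(); preprocessing, padding, and inference post-processing are omitted):

def forward_sketch(self, images, targets):
    # backbone: image tensor -> multi-scale feature dict {"res2": ..., "res3": ..., "res4": ..., "res5": ...}
    features = self.backbone(images.tensor)
    # sem_seg_head: features -> {"pred_logits": [b, q, num_classes + 1],
    #                            "pred_masks":  [b, q, h, w], "aux_outputs": [...]}
    outputs = self.sem_seg_head(features)
    if self.training:
        # criterion: Hungarian matching + classification / mask losses
        return self.criterion(outputs, targets)
    return outputs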

2. backbone
The authors provide two backbones, ResNet and Swin Transformer. The output is a dict built by outs["res{}".format(i + 2)] = out, whose keys are res2, res3, and so on.
def forward(self, x):
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
)
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = {}
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f"norm{i}")
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs["res{}".format(i + 2)] = out
return outs
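As a concrete illustration, for a [b, 3, 512, 512] input the returned dict looks roughly like this (strides 4/8/16/32; the channel counts below assume Swin-T and differ for other backbone variants):

# outs = backbone(images)                      # images: [b, 3, 512, 512]
# outs["res2"].shape == [b,  96, 128, 128]     # stride 4
# outs["res3"].shape == [b, 192,  64,  64]     # stride 8
# outs["res4"].shape == [b, 384,  32,  32]     # stride 16
# outs["res5"].shape == [b, 768,  16,  16]     # stride 32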
3. MaskFormerHead
The core of the whole model lives in MaskFormerHead, which implements the Decoder part of the architecture diagram. Start with the forward() function:
def forward(self, features):
return self.layers(features)
def layers(self, features):
mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
if self.transformer_in_feature == "transformer_encoder":
assert (
transformer_encoder_features is not None
), "Please use the TransformerEncoderPixelDecoder."
predictions = self.predictor(transformer_encoder_features, mask_features)
else:
predictions = self.predictor(features[self.transformer_in_feature], mask_features)
return predictions
In short, the predictor processes the output of the pixel_decoder and returns the predictions directly, so the two classes to focus on are pixel_decoder and predictor.
3.1 pixel_decoder
@classmethod
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
return {
"input_shape": {
k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
},
"ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
"num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
"pixel_decoder": build_pixel_decoder(cfg, input_shape),
"loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
"transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
"transformer_predictor": TransformerPredictor(
cfg,
cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
mask_classification=True,
),
}
The from_config() of mask_former_head.py shows that pixel_decoder comes from build_pixel_decoder(). Inside build_pixel_decoder(), the pixel decoder is selected by cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME from the configuration file; when debugging, this parameter is name = BasePixelDecoder.
def build_pixel_decoder(cfg, input_shape):
"""
Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`.
"""
name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
forward_features = getattr(model, "forward_features", None)
if not callable(forward_features):
raise ValueError(
"Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
f"Please implement forward_features for {name} to only return mask features."
)
return model
BasePixelDecoder is defined in mask_former/modeling/heads/pixel_decoder.py, and TransformerEncoderPixelDecoder inherits from it; this subclass is the pixel_decoder we need. Its forward_features() returns two values, self.mask_features(y) and transformer_encoder_features, which correspond exactly to mask_features and transformer_encoder_features in layers().
def forward_features(self, features):
# Reverse feature maps into top-down order (from low to high resolution)
for idx, f in enumerate(self.in_features[::-1]):
x = features[f]
lateral_conv = self.lateral_convs[idx]
output_conv = self.output_convs[idx]
if lateral_conv is None:
transformer = self.input_proj(x)
pos = self.pe_layer(x)
transformer = self.transformer(transformer, None, pos)
y = output_conv(transformer)
# save intermediate feature as input to Transformer decoder
transformer_encoder_features = transformer
else:
cur_fpn = lateral_conv(x)
# Following FPN implementation, we use nearest upsampling here
y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
y = output_conv(y)
return self.mask_features(y), transformer_encoder_features
Both self.mask_features() and the lateral_conv used to compute y come from the base class BasePixelDecoder. Looking at BasePixelDecoder's forward_features(), it is really just an ordinary FPN: the features extracted by the backbone are upsampled and added together level by level, and each merged level is passed through its output_conv().
Note that when the 'res5' feature level is reached, lateral_convs.append(None), because 'res5' is the feature that goes through the TransformerEncoder.
So the first value returned by TransformerEncoderPixelDecoder, self.mask_features(y), is simply the FPN output. In ordinary semantic segmentation this result would be used directly for loss computation and prediction, but here it only serves as the mask feature map: it is later multiplied with the mask_embed produced by the Transformer to obtain the final pred_masks.
The second return value, transformer_encoder_features, is produced by self.transformer(). This self.transformer simply runs standard self-attention over the 'res5' feature level, computing attention for every pixel.
if lateral_conv is None:
    transformer = self.input_proj(x)  # project res5 channels down to conv_dim
    pos = self.pe_layer(x)
    transformer = self.transformer(transformer, None, pos)  # pos is the positional embedding
    y = output_conv(transformer)
    # save intermediate feature as input to Transformer decoder
    transformer_encoder_features = transformer
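Conceptually, self.transformer flattens the spatial dimensions of 'res5' into a token sequence and runs a standard Transformer encoder over it. Below is a minimal sketch using plain torch.nn layers (positional embeddings omitted; this is not the repo's actual Transformer encoder class):

import torch
import torch.nn as nn

b, c, h, w = 2, 256, 16, 16                  # res5 after input_proj (conv_dim channels), toy sizes
x = torch.randn(b, c, h, w)

encoder_layer = nn.TransformerEncoderLayer(d_model=c, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

seq = x.flatten(2).transpose(1, 2)           # [b, h*w, c]: every pixel becomes a token
seq = encoder(seq)                           # self-attention among all pixels
out = seq.transpose(1, 2).view(b, c, h, w)   # back to a feature map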
3.2 predictor
Looking again at MaskFormerHead's from_config(), we can see
"transformer_predictor": TransformerPredictor()
So the predictor we need is the TransformerPredictor class in mask_former/modeling/transformer/transformer_predictor.py. Start from its forward().
def forward(self, x, mask_features):  # x = transformer_encoder_features, mask_features = mask_features
    pos = self.pe_layer(x)
    src = x
    mask = None
    # hs and memory are the Transformer output and the output of the Transformer encoder, respectively
    hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)
    if self.mask_classification:
        outputs_class = self.class_embed(hs)  # classification branch, output [b, q, num_classes + 1]
        out = {"pred_logits": outputs_class[-1]}
    else:
        out = {}
    # compute the mask outputs
    if self.aux_loss:  # whether to also compute losses on the intermediate Transformer decoder layers
        # [l, bs, queries, embed], where l is the number of Transformer decoder layers
        mask_embed = self.mask_embed(hs)  # project the Transformer output to the same dim as mask_features
        # one matrix product with mask_features gives the final mask branch output; masks have shape [b, q, h, w]
        outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
        out["pred_masks"] = outputs_seg_masks[-1]  # keep only the last decoder layer
        out["aux_outputs"] = self._set_aux_loss(
            outputs_class if self.mask_classification else None, outputs_seg_masks
        )
    else:  # only use the final Transformer output to compute the loss
        # FIXME h_boxes takes the last one computed, keep this in mind
        # [bs, queries, embed]
        mask_embed = self.mask_embed(hs[-1])
        outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
        out["pred_masks"] = outputs_seg_masks
    return out
As we can see, the masks have shape [b, q, h, w]; the l in the einsum is the number of Transformer decoder layers.
In semantic segmentation, the usual approach is to assign a class to every pixel of the feature map, so the output has shape [b, num_class, h, w]. This is fine for semantic segmentation, but in instance segmentation it becomes very hard to separate two overlapping instances of the same class.
MaskFormer takes a different view: the output feature map has shape [b, q, h, w], where q is the number of queries. Each of the q maps is a binary mask, and the class of that mask is determined by a separate classification branch whose output has shape [b, q, num_class].
To summarize, the previous per-pixel classification output is replaced by two branches: one predicts classes, with output shape [b, q, num_class]; the other predicts a mask for each instance in the image, with output shape [b, q, h, w], where q is the number of queries, or equivalently the maximum number of instances predicted per image. (In the source code, num_queries = 100 and stuff_classes = 150.)
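At inference time for semantic segmentation, the two branches are combined back into an ordinary per-pixel class score map by weighting every binary mask with its class probabilities and summing over the queries. A sketch of that combination for a single image (this mirrors the model's semantic inference post-processing, simplified here):

import torch
import torch.nn.functional as F

def semantic_inference_sketch(pred_logits, pred_masks):
    # pred_logits: [q, num_class + 1] (last channel is "no object"), pred_masks: [q, h, w]
    mask_cls = F.softmax(pred_logits, dim=-1)[..., :-1]  # drop the no-object class -> [q, num_class]
    mask_pred = pred_masks.sigmoid()                     # per-pixel binary mask probabilities
    # weight every binary mask by its class probabilities and sum over queries
    return torch.einsum("qc,qhw->chw", mask_cls, mask_pred)  # [num_class, h, w]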
4. Loss Function
The paper's contribution is a way of unifying semantic and instance segmentation; the authors argue that the key to this unification lies in how the loss is computed and how the network outputs are post-processed.
4.1 The Hungarian Algorithm
The loss is computed in mask_former/modeling/criterion.py. As described in Section 3, the output has shape [b, q, h, w], and not every one of the q [h, w] maps corresponds to an actual object, so how do we compute the loss? As in DETR, the Hungarian algorithm is used to solve this.
Given q predicted masks and i targets (an image contains i objects, and i varies from image to image), we need to find, among the q predictions, the assignment to the i targets that minimizes the total cost. This is exactly what mask_former/modeling/matcher.py does.
def memory_efficient_forward(self, outputs, targets):
"""More memory-friendly matching"""
bs, num_queries = outputs["pred_logits"].shape[:2]
# Work out the mask padding size
masks = [v["masks"] for v in targets]
h_max = max([m.shape[1] for m in masks])
w_max = max([m.shape[2] for m in masks])
indices = []
# Iterate through batch size
for b in range(bs):
out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes]
out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred]
# tgt_ids stores the class indices that appear in this image, e.g. if the image contains a person,
# a car and a plane with class_ids 1, 3, 7, then tgt_ids = [1, 3, 7]
tgt_ids = targets[b]["labels"]
# gt masks are already padded when preparing target
tgt_mask = targets[b]["masks"].to(out_mask)
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
cost_class = -out_prob[:, tgt_ids]
# Downsample gt masks to save memory
tgt_mask = F.interpolate(tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest")
# Flatten spatial dimension
out_mask = out_mask.flatten(1) # [batch_size * num_queries, H*W]
tgt_mask = tgt_mask[:, 0].flatten(1) # [num_total_targets, H*W]
# Compute the focal loss between masks
cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask)
# Compute the dice loss betwen masks
cost_dice = batch_dice_loss(out_mask, tgt_mask)
# Final cost matrix
C = (
self.cost_mask * cost_mask
+ self.cost_class * cost_class
+ self.cost_dice * cost_dice
)
C = C.reshape(num_queries, -1).cpu()
indices.append(linear_sum_assignment(C))
return [
(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
for i, j in indices
]
In the source, the actual Hungarian matching is done by calling linear_sum_assignment() from scipy; everything before that call just builds the cost matrix C. The C finally passed to linear_sum_assignment() has shape [num_queries, num_targets] (the number of ground-truth objects in that image); the function returns two sequences, the first containing row indices and the second the corresponding column indices.
In practice num_queries is rarely equal to the number of targets, i.e. the input is not a square matrix, so what does the function return? For a non-square matrix, both returned sequences have length min(num_queries, num_targets); the authors explain this clearly in the docstring of the matcher's forward().
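A quick example of this non-square behaviour (plain scipy, nothing MaskFormer-specific):

import numpy as np
from scipy.optimize import linear_sum_assignment

C = np.random.rand(100, 3)             # 100 queries, 3 ground-truth objects in this image
rows, cols = linear_sum_assignment(C)  # both arrays have length min(100, 3) == 3
# rows: indices of the 3 matched queries; cols: a permutation of [0, 1, 2]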
4.2 Computing the Loss
After the Hungarian matching above, each ground-truth object has been matched to its most similar query; the next step is to compute the loss with the matched queries.
Back in the criterion, a classification loss and a mask loss are computed.
The mask loss consists of a dice loss and a per-pixel focal loss. Note that the focal loss here applies a sigmoid to every pixel rather than the usual softmax: the class information is provided by the classification branch, and each pixel of a mask only indicates foreground vs. background, hence the sigmoid.
The classification loss is an ordinary CE loss.
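For reference, minimal versions of the two mask losses (simplified from the DETR-style helpers used in criterion.py; treat them as a sketch rather than the exact implementation):

import torch.nn.functional as F

def dice_loss(inputs, targets, num_masks):
    # inputs: [n, h*w] mask logits of the matched queries, targets: [n, h*w] binary ground truth
    inputs = inputs.sigmoid()
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    return (1 - (numerator + 1) / (denominator + 1)).sum() / num_masks

def sigmoid_focal_loss(inputs, targets, num_masks, alpha=0.25, gamma=2.0):
    # per-pixel sigmoid + BCE (foreground vs. background), down-weighting easy pixels;
    # no softmax over classes here because the class is predicted by the separate branch
    prob = inputs.sigmoid()
    ce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * ce).mean(1).sum() / num_masks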
def forward(self, outputs, targets):
"""This performs the loss computation.
Parameters:
outputs: dict of tensors, see the output specification of the model for the format
targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
# Retrieve the matching between the outputs of the last layer and the targets
indices = self.matcher(outputs_without_aux, targets)
# Compute the average number of target boxes accross all nodes, for normalization purposes
num_masks = sum(len(t["labels"]) for t in targets)
num_masks = torch.as_tensor(
[num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
)
if is_dist_avail_and_initialized():
torch.distributed.all_reduce(num_masks)
num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
# Compute all the requested losses
losses = {}
for loss in self.losses:
losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if "aux_outputs" in outputs:
for i, aux_outputs in enumerate(outputs["aux_outputs"]):
indices = self.matcher(aux_outputs, targets)
for loss in self.losses:
l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
losses.update(l_dict)
return losses