Segment Anything模型架构深度解析:重新定义图像分割的新范式
你是否还在为复杂的图像分割任务烦恼?是否希望有一种模型能够通过简单的点选或框选就能精准分割出图像中的任意物体?Segment Anything模型(SAM)的出现,彻底改变了传统图像分割的工作流程。本文将深入解析SAM的核心架构,带你了解它如何实现"点哪儿分哪儿"的神奇功能,以及背后的技术创新。读完本文,你将能够:
- 理解SAM模型的三大核心组件及其协作方式
- 掌握图像从输入到生成掩码的完整流程
- 了解SAM如何处理不同类型的用户提示
- 认识SAM在实际应用中的表现和优势
SAM整体架构概览
Segment Anything模型采用了模块化设计,主要由三个核心组件构成:图像编码器(Image Encoder)、提示编码器(Prompt Encoder)和掩码解码器(Mask Decoder)。这种架构设计使得模型能够灵活处理多种输入提示,并高效生成高质量的分割掩码。
SAM的核心思想是将图像编码为通用的特征表示,然后根据用户提供的提示(如点、框或掩码),通过提示编码器生成相应的特征,最后由掩码解码器结合这两种特征生成精确的分割掩码。整个流程由segment_anything/modeling/sam.py中的Sam类统一协调。
核心组件协作流程
- 图像编码器将输入图像转换为高维特征图
- 提示编码器处理用户输入的各种提示,生成提示特征
- 掩码解码器结合图像特征和提示特征,生成最终的分割掩码和质量分数
下面我们将逐一解析这三个核心组件的内部结构和工作原理。
图像编码器:视觉特征的提取器
图像编码器是SAM的"眼睛",负责将原始图像转换为具有丰富语义信息的特征表示。SAM采用了基于Vision Transformer(ViT)的架构,并进行了针对性优化,以适应图像分割任务的需求。
ViT架构的创新应用
SAM的图像编码器实现于segment_anything/modeling/image_encoder.py中的ImageEncoderViT类。它将输入图像分割为16x16的 patches,通过补丁嵌入(Patch Embedding)将每个patch转换为向量,并添加位置嵌入(Positional Embedding)以保留空间信息。
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
# ... 其他参数
) -> None:
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
# ... 位置嵌入和Transformer块初始化
混合注意力机制
为了在保持计算效率的同时捕获长距离依赖关系,SAM的图像编码器采用了混合注意力机制:大部分Transformer块使用窗口注意力(Window Attention)以提高计算效率,而少数块使用全局注意力(Global Attention)以捕获全局上下文信息。
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
# ... 其他参数
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
特征调整颈部网络
在Transformer编码器之后,SAM添加了一个颈部网络(Neck),用于将Transformer输出的特征调整为适合后续掩码解码器处理的维度(默认256维):
self.neck = nn.Sequential(
nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False),
LayerNorm2d(out_chans),
nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
LayerNorm2d(out_chans),
)
提示编码器:理解用户意图
提示编码器是SAM的"耳朵",负责将用户提供的各种提示转换为模型能够理解的特征表示。SAM支持多种提示类型,包括点、框和掩码,这些都在segment_anything/modeling/prompt_encoder.py中实现。
点和框提示编码
对于点和框提示,SAM使用位置编码(Positional Encoding)将空间坐标转换为特征向量。点提示分为正点(目标内部)和负点(目标外部),框提示则由其两个对角顶点表示:
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
points = points + 0.5 # Shift to center of pixel
if pad:
# 添加填充点以确保至少有一个点
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
# 使用随机位置编码
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
# 根据标签添加不同的嵌入
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight # 负点
point_embedding[labels == 1] += self.point_embeddings[1].weight # 正点
return point_embedding
掩码提示编码
对于掩码提示,SAM使用卷积神经网络将输入掩码压缩为低维特征:
self.mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
位置编码的创新
SAM采用了随机位置编码(Random Position Encoding)而非传统的正弦余弦位置编码,这有助于提高模型的泛化能力:
class PositionEmbeddingRandom(nn.Module):
"""使用随机空间频率的位置编码"""
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
self.register_buffer(
"positional_encoding_gaussian_matrix",
scale * torch.randn((2, num_pos_feats)),
)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
coords = 2 * coords - 1 # 归一化到[-1, 1]
coords = coords @ self.positional_encoding_gaussian_matrix # 投影到随机矩阵
coords = 2 * np.pi * coords # 缩放
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) # 生成正弦余弦编码
掩码解码器:生成精确分割
掩码解码器是SAM的"手",负责结合图像特征和提示特征,生成最终的分割掩码。这一过程在segment_anything/modeling/mask_decoder.py中实现。
Transformer解码器
掩码解码器使用一个小型Transformer来处理图像特征和提示特征:
class MaskDecoder(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
# ... 其他参数
) -> None:
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer # 用于处理图像和提示特征的Transformer
# ...
动态掩码生成
掩码解码器的核心创新在于引入了动态掩码生成机制。模型会预测多个候选掩码,并为每个掩码生成一个质量分数,供用户选择或自动选取最优结果:
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 预测掩码和质量分数
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# 根据是否需要多掩码输出选择不同的掩码切片
if multimask_output:
mask_slice = slice(1, None) # 多掩码输出(3个结果)
else:
mask_slice = slice(0, 1) # 单掩码输出(最佳结果)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
return masks, iou_pred
掩码上采样
由于Transformer输出的特征图分辨率较低,掩码解码器使用转置卷积(Transposed Convolution)进行上采样,以生成高分辨率的掩码:
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
实际应用示例
SAM的强大功能可以通过项目提供的示例笔记本直观地展示。下面我们将介绍几个典型的应用场景。
自动掩码生成
对于没有特定提示的情况,SAM可以自动生成图像中所有物体的分割掩码。这一功能在notebooks/automatic_mask_generator_example.ipynb中展示:
自动掩码生成器会分析图像内容,为每个检测到的物体生成一个分割掩码,这对于图像内容分析、物体计数等任务非常有用。
交互式分割
SAM最强大的功能是交互式分割,用户只需提供少量提示(如几个点),模型就能快速分割出目标物体。notebooks/predictor_example.ipynb展示了这一功能:
通过简单的点选,SAM就能精准分割出图像中的狗。用户还可以通过添加更多提示来细化分割结果,实现高精度的交互式分割。
ONNX模型部署
为了方便在实际应用中部署,SAM提供了将模型导出为ONNX格式的功能。scripts/export_onnx_model.py脚本可以将模型导出为ONNX格式,以便在各种平台上高效运行。导出的ONNX模型可以通过notebooks/onnx_model_example.ipynb进行验证和使用。
结语
Segment Anything模型通过创新的架构设计,实现了图像分割任务的重大突破。其三大核心组件——图像编码器、提示编码器和掩码解码器——协同工作,使得模型能够灵活处理多种提示类型,并生成高质量的分割掩码。
SAM的出现不仅简化了图像分割的工作流程,还为许多下游任务如图像编辑、目标检测、语义分析等提供了强大的基础工具。随着技术的不断发展,我们有理由相信,SAM将在计算机视觉领域发挥越来越重要的作用。
如果你对SAM的实现细节感兴趣,可以通过查阅源代码进一步深入学习。项目的核心代码位于segment_anything/目录下,包含了模型实现的所有细节。此外,项目提供的示例笔记本也是学习如何使用SAM的绝佳资源。
通过本文的介绍,希望你对Segment Anything模型的架构有了深入的理解,能够更好地利用这一强大工具解决实际问题。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考






