Segment Anything Model代码讲解(一)之SAM

​​在这里插入图片描述

SAM代码内容解析

导入依赖包

import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder

def postprocess_masks(
        self,
        masks: torch.Tensor,      #掩码
        input_size: Tuple[int, ...],        #输入
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.
        将填充和高清掩膜解码的过程去除,使掩膜恢复到原始图像的大小。
        Arguments:
          masks (torch.Tensor): 从掩码解码器得到的批处理掩码,格式为BxCxHxW
          input_size (tuple(int, int)): 输入到模型的图像大小,格式为(H,W),用于去除填充
          original_size (tuple(int, int)): 输入模型之前的原始图像大小,格式为(H,W)
        Returns:
          (torch.Tensor):BxCxHxW格式的批量掩码,其中(H,W)由original_size给出.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        ) #线性插值算法进行掩码
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks
#定义预处理方法
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors 归一化
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad 填充到self.image_encoder.img_size的正方形
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,  #传入一个ImageEncoderViT类型的参数
        prompt_encoder: PromptEncoder,   #传入一个PromptEncoder类型的参数
        mask_decoder: MaskDecoder,       #传入一个MaskDecoder类型的参数
        pixel_mean: List[float] = [123.675, 116.28, 103.53],  #对输入图像中的像素进行归一化的平均值。
        pixel_std: List[float] = [58.395, 57.12, 57.375],     #用于对输入图像中的像素进行标准化的标准值。
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.
       """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        
        #self.register_buffer用于存储像素值的平均值,它被转换为torch.Tensor并传递给这个方法,以供模型在处理图像时使用。
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    #@property装饰器让使用方法访问属性值的代码看起来像访问实例属性一样,同时为属性访问提供更多的控制和保护。
    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    #torch.no_grad()装饰器主要用来禁用PyTorch中的梯度计算,以降低内存消耗和提高代码效率        
    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.
        Arguments:
          batched_input (list(dict)): 一个输入图像的列表,每个图像为一个字典,具有以下键。        如果不存在,可以省略提示键。.
              'image':  3xHxW格式的torch张量形式的图像,已经转换为输入到模型中。
              'original_size': (tuple(int,int))转换前图像的原始大小,格式为(H,W)
              'point_coords':(torch.Tensor)此图像的批处理点提示,形状为BxNx2。已转换为模型的输入帧
              'point_labels': (torch.Tensor)点提示的标签,形状为BxN的批量标签
              'boxes': torch.Tensor)批量框输入,形状为Bx4。已转换到模型的输入box
              'mask_inputs': 模型输入的批量掩码输入,形式为Bx1xHxW
          	  'multimask_output (bool)': 模型是否应预测多个消除歧义的掩码,还是返回单个掩码
        Returns:
         	  '(list(dict))': 包含每个图像的字典列表,其中每个元素具有以下键
              'masks': (torch.Tensor) 批量二进制掩码预测,形状为BxCxHxW,其中B是输入提示的数量,C由multimask_output确定,(H,W)是原始图像的大小
              'iou_predictions': torch.Tensor)掩码质量的模型预测,形状为BxC.
              'low_res_logits': (torch.Tensor) 低分辨率的logits,形状为BxCxHxW,其中H=W=256。可以将其作为掩码输入传递给后续的预测迭代
        """
        #将图片进行批处理后放入对象input_images的栈中(torch.stack)
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        
        #怼图片进行image_encode形式的编码嵌入
        image_embeddings = self.image_encoder(input_images)
        
        #定义局部变量outputs作为输出
        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            
            #解码的结果
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            
			#对mask进行处理得到原始图像大小尺度
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            
            #将预测出的mask、iou_predictions、low_res_logits的结果放入output中
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs
### 使用 Segment Anything Model (SAM) 实现摇杆图像的语义分割 #### SAM 的核心功能概述 Segment Anything Model (SAM)种强大的通用分割模型,其设计目标是实现“切的分割”,即它可以处理任意类型的对象并对其进行精确分割[^1]。相比于传统的语义分割方法仅关注预定义的对象类别,或者实例分割专注于区分不同实例,SAM 提供了种更灵活的方式来进行分割。 当应用于摇杆图像时,SAM 不需要事先知道摇杆的具体形状或类别即可完成分割任务。这得益于它的提示机制(prompt-based mechanism),允许用户通过简单输入(如点、边界框或其他形式的提示)引导模型聚焦于感兴趣的区域。 --- #### 数据准备与环境搭建 为了使用 SAM 进行摇杆图像的语义分割,需先准备好必要的数据和工具: 1. **安装依赖库** 需要安装 PyTorch 和 SAM 官方提供的 Python 库 `segment_anything`。以下是安装命令: ```bash pip install torch torchvision pip install git+https://github.com/facebookresearch/segment-anything.git ``` 2. **下载模型权重文件** 下载官方发布的 SAM 模型权重文件(例如 `sam_vit_h_4b8939.pth` 或其他变体)。可以从 Facebook AI Research 的 GitHub 页面获取这些资源。 3. **加载模型** 加载 SAM 模型及其配置参数可以通过以下代码片段实现: ```python from segment_anything import sam_model_registry, SamPredictor model_type = "vit_h" # 可选:vit_b, vit_l, vit_h checkpoint_path = "path/to/sam_vit_h_4b8939.pth" device = "cuda" if torch.cuda.is_available() else "cpu" sam = sam_model_registry[model_type](checkpoint=checkpoint_path).to(device) predictor = SamPredictor(sam) ``` --- #### 输入提示与预测过程 对于摇杆图像的分割任务,可以采用如下方式设置提示并向 SAM 提供输入信息: 1. **读取图像** 将待分割的摇杆图像加载到内存中,并传递给 SAM 的预测器。 ```python image = cv2.imread("joystick_image.jpg") image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) predictor.set_image(image) ``` 2. **指定提示类型** 用户可以根据需求选择不同的提示模式,比如单击感兴趣区域内的某个位置作为点提示,或者绘制个粗略包围摇杆的矩形框作为边界框提示。下面分别展示两种常见情况下的代码示例: - **基于点提示** 如果已知摇杆的大致中心位置,则可以直接向模型提供该点坐标。 ```python input_point = np.array([[x_center, y_center]]) # 替换为实际像素坐标 input_label = np.array([1]) # 表示此点属于前景 masks, scores, logits = predictor.predict( point_coords=input_point, point_labels=input_label, multimask_output=False ) ``` - **基于边界框提示** 当希望限定搜索范围至某特定区域内时,可利用边界框代替点提示。 ```python bbox = np.array([xmin, ymin, xmax, ymax]) # 替换为真实值 mask, _, _ = predictor.predict(box=bbox, multimask_output=False) ``` 3. **可视化结果** 得到掩码后,可通过叠加显示原始图片与分割后的效果以便验证准确性。 ```python plt.figure(figsize=(10,10)) plt.imshow(image) show_mask(masks[0], plt.gca(), random_color=True) show_points(input_point, input_label, plt.gca()) plt.axis('off') plt.show() ``` 上述过程中提到的功能函数(如 `show_mask`, `show_points` 等辅助绘图工具)通常由开发者自行编写或参考官方文档中的实现版本。 --- #### 结果优化建议 尽管 SAM 已具备较强的泛化能力,但在某些复杂场景下仍可能遇到挑战。此时可以考虑调整以下几个方面来提升性能表现: - 增加更多样化的提示组合; - 对低质量输出重新运行迭代改进流程; - 利用额外训练样本微调基础网络结构以适应具体领域特性。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

果粒橙_LGC

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值