sam2
The key points of this project's real-time inference implementation are summarized below.
Real-Time Inference Implementation Analysis
This is a real-time video segmentation project built on SAM2 (Segment Anything Model 2). It performs online segmentation and tracking on a live video stream.
1. Core Architecture
The project uses a SAM2CameraPredictor class to handle the live video stream:
```python
class SAM2CameraPredictor(SAM2Base):
    """The predictor class to handle user interactions and manage inference states."""
    def __init__(
        self,
        fill_hole_area=0,
        # whether to apply non-overlapping constraints on the output object masks
        non_overlap_masks=False,
        # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
        # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
        clear_non_cond_mem_around_input=False,
        # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
        clear_non_cond_mem_for_multi_obj=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.fill_hole_area = fill_hole_area
        self.non_overlap_masks = non_overlap_masks
        self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
        self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
        self.condition_state = {}
        self.frame_idx = 0
```
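In practice the predictor is obtained through the repository's build helper rather than constructed directly. A minimal sketch of that usage, assuming the `build_sam2_camera_predictor` helper from the project; the config and checkpoint paths are placeholders:

```python
from sam2.build_sam import build_sam2_camera_predictor

# Placeholder paths; point these at your local config and checkpoint.
model_cfg = "sam2_hiera_s.yaml"
checkpoint = "./checkpoints/sam2_hiera_small.pt"
predictor = build_sam2_camera_predictor(model_cfg, checkpoint)
```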
2. Real-Time Inference Pipeline
2.1 Initialization (First Frame)
```python
@torch.inference_mode()
def load_first_frame(self, img):
    self.condition_state = self._init_state(
        offload_video_to_cpu=False, offload_state_to_cpu=False
    )
    img, width, height = self.perpare_data(img, image_size=self.image_size)
    self.condition_state["images"] = [img]
    self.condition_state["num_frames"] = len(self.condition_state["images"])
    self.condition_state["video_height"] = height
    self.condition_state["video_width"] = width
    self._get_image_feature(frame_idx=0, batch_size=1)
```
- Load and preprocess the first frame (see the preprocessing sketch below)
- Initialize the inference state (condition_state)
- Extract image features for the first frame
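The preprocessing lives in `perpare_data` (the method name is spelled this way in the project); its body is not shown above. The sketch below is a hypothetical stand-in that assumes SAM2's usual resize-and-normalize pipeline:

```python
import torch
import torch.nn.functional as F

def prepare_frame(frame_rgb, image_size=1024,
                  mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Hypothetical stand-in for perpare_data: HWC uint8 RGB -> normalized CHW float tensor."""
    height, width = frame_rgb.shape[:2]
    img = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0      # CHW, values in [0, 1]
    img = F.interpolate(img[None], size=(image_size, image_size),
                        mode="bilinear", align_corners=False)[0]            # square model input
    img = (img - torch.tensor(mean)[:, None, None]) / torch.tensor(std)[:, None, None]
    return img, width, height
```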
2.2 Tracking (Subsequent Frames)
```python
@torch.inference_mode()
def track(
    self,
    img,
):
    self.frame_idx += 1
    self.condition_state["num_frames"] += 1
    if not self.condition_state["tracking_has_started"]:
        self.propagate_in_video_preflight()
    img, _, _ = self.perpare_data(img, image_size=self.image_size)
    output_dict = self.condition_state["output_dict"]
    obj_ids = self.condition_state["obj_ids"]
    batch_size = self._get_obj_num()
    # Retrieve correct image features
    (
        _,
        _,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
    ) = self._get_feature(img, batch_size)
    current_out = self.track_step(
        frame_idx=self.frame_idx,
        is_init_cond_frame=False,
        current_vision_feats=current_vision_feats,
        current_vision_pos_embeds=current_vision_pos_embeds,
        feat_sizes=feat_sizes,
        point_inputs=None,
        mask_inputs=None,
        output_dict=output_dict,
        num_frames=self.condition_state["num_frames"],
        track_in_reverse=False,
        run_mem_encoder=True,
        prev_sam_mask_logits=None,
    )
    # optionally offload the output to CPU memory to save GPU space
    storage_device = self.condition_state["storage_device"]
    maskmem_features = current_out["maskmem_features"]
    if maskmem_features is not None:
        maskmem_features = maskmem_features.to(torch.bfloat16)
        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
    pred_masks_gpu = current_out["pred_masks"]
    # potentially fill holes in the predicted masks
    if self.fill_hole_area > 0:
        pred_masks_gpu = fill_holes_in_mask_scores(
            pred_masks_gpu, self.fill_hole_area
        )
    pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
    # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
    maskmem_pos_enc = self._get_maskmem_pos_enc(current_out)
    # object pointer is a small tensor, so we always keep it on GPU memory for fast access
    obj_ptr = current_out["obj_ptr"]
    # make a compact version of this frame's output to reduce the state size
    current_out = {
        "maskmem_features": maskmem_features,
        "maskmem_pos_enc": maskmem_pos_enc,
        "pred_masks": pred_masks,
        "obj_ptr": obj_ptr,
    }
    # output_dict[storage_key][self.frame_idx] = current_out
    self._manage_memory_obj(self.frame_idx, current_out)
    _, video_res_masks = self._get_orig_video_res_output(pred_masks_gpu)
    return obj_ids, video_res_masks
```
Per-frame processing:
- Extract features for the current frame (the feature-sharing trick is sketched below)
- Call track_step to fuse them with the historical memory
- Produce the segmentation masks
- Update the memory state
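One detail worth noting: the image backbone runs only once per frame, and the resulting features are shared across all tracked objects (batch_size equals the number of objects). The snippet below illustrates that expand trick on a dummy tensor; it is not the project's `_get_feature` code:

```python
import torch

feat = torch.randn(1, 256, 64, 64)             # features from a single backbone forward pass
num_objects = 3                                 # batch_size == number of tracked objects
shared = feat.expand(num_objects, -1, -1, -1)   # a view: no extra compute, no extra memory copy
print(shared.shape)                             # torch.Size([3, 256, 64, 64])
```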
3. Key Optimization Techniques
3.1 Memory Management
Memory growth is bounded through _manage_memory_obj:
```python
def _manage_memory_obj(self, frame_idx, current_out):
    output_dict = self.condition_state["output_dict"]
    non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
    non_cond_frame_outputs[frame_idx] = current_out
    key_list = [key for key in output_dict["non_cond_frame_outputs"]]
    #! TODO: better way to manage memory
    if len(non_cond_frame_outputs) > self.num_maskmem:
        for t in range(0, len(non_cond_frame_outputs) - self.num_maskmem):
            # key, Value = non_cond_frame_outputs.popitem(last=False)
            _ = non_cond_frame_outputs.pop(key_list[t], None)
```
- Only the memories of the most recent num_maskmem frames (7 by default) are kept
- Once the limit is exceeded, the oldest frames are evicted automatically, so memory does not grow without bound (see the sketch below)
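A minimal, self-contained sketch of the same sliding-window idea (not the project's code), using an OrderedDict so the oldest entry is always evicted first:

```python
from collections import OrderedDict

num_maskmem = 7                        # keep at most this many per-frame outputs
non_cond_frame_outputs = OrderedDict()

def store_frame_output(frame_idx, output):
    non_cond_frame_outputs[frame_idx] = output
    while len(non_cond_frame_outputs) > num_maskmem:
        non_cond_frame_outputs.popitem(last=False)   # drop the oldest frame first
```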
3.2 Mixed-Precision Inference
bfloat16 is used to reduce memory usage and compute cost:
```python
maskmem_features = current_out["maskmem_features"]
if maskmem_features is not None:
    maskmem_features = maskmem_features.to(torch.bfloat16)
    maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
```
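For intuition, the cast halves the storage of each memory feature tensor: bfloat16 keeps float32's dynamic range but uses 2 bytes per element instead of 4. A quick check with a dummy tensor:

```python
import torch

x32 = torch.randn(4, 64, 64, 64)                 # dummy float32 memory features
x16 = x32.to(torch.bfloat16)
print(x32.element_size(), x16.element_size())    # 4 vs 2 bytes per element (~2x smaller)
```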
3.3 CPU/GPU Memory Scheduling
Non-critical data can be offloaded to CPU memory:
```python
# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
self.condition_state["offload_video_to_cpu"] = offload_video_to_cpu
# whether to offload the inference state to CPU memory
# turning on this option saves the GPU memory at the cost of a lower tracking fps
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
# and from 24 to 21 when tracking two objects)
self.condition_state["offload_state_to_cpu"] = offload_state_to_cpu
# inference itself always runs on the GPU; only storage may move to the CPU
self.condition_state["device"] = torch.device("cuda")
if offload_state_to_cpu:
    self.condition_state["storage_device"] = torch.device("cpu")
else:
    self.condition_state["storage_device"] = torch.device("cuda")
```
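The offload round-trip itself is just asynchronous `.to()` calls: per-frame outputs are parked on `storage_device` between frames and moved back to the GPU when they are needed as memory for attention (a no-op if they are already there). A minimal sketch, assuming a CUDA device is available:

```python
import torch

gpu = torch.device("cuda")
storage_device = torch.device("cpu")                   # as set when offload_state_to_cpu=True

feats = torch.randn(1, 64, 64, 64, device=gpu)         # dummy per-frame memory features
feats = feats.to(storage_device, non_blocking=True)    # park on CPU between frames
# ... later, when this frame is attended to as memory:
feats = feats.to(gpu, non_blocking=True)               # no-op if already on the GPU
```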
3.4 Model Compilation (Optional)
SAM2CameraPredictorVOS uses torch.compile to accelerate key submodules:
```python
class SAM2CameraPredictorVOS(SAM2CameraPredictor):
    """Optimized for the VOS setting"""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.compile_memory_encoder = kwargs.get("compile_memory_encoder", False)
        self.compile_memory_attention = kwargs.get("compile_memory_attention", False)
        self.compile_prompt_encoder = kwargs.get("compile_prompt_encoder", False)
        self.compile_mask_decoder = kwargs.get("compile_mask_decoder", False)
        self._compile_all_components()
    def _compile_all_components(self):
        print("Compiling all components for VOS setting. First time may be very slow.")
        if self.compile_memory_encoder:
            print("Compiling memory encoder...")
            self.memory_encoder.forward = torch.compile(
                self.memory_encoder.forward,
                mode="max-autotune",
                fullgraph=True,
                dynamic=False,
            )
        if self.compile_memory_attention:
            print("Compiling memory attention...")
            self.memory_attention.forward = torch.compile(
                self.memory_attention.forward,
                mode="max-autotune",
                fullgraph=True,
                dynamic=True,
            )
        if self.compile_prompt_encoder:
            self.sam_prompt_encoder.forward = torch.compile(
                self.sam_prompt_encoder.forward,
                mode="max-autotune",
                fullgraph=True,
                dynamic=False,  # Accuracy regression on True
            )
        if self.compile_mask_decoder:
            self.sam_mask_decoder.forward = torch.compile(
                self.sam_mask_decoder.forward,
                mode="max-autotune",
                fullgraph=True,
                dynamic=False,  # Accuracy regression on True
            )
```
- Key components are compiled to speed up inference
- The first run is slow (compilation and autotuning); subsequent runs are faster (see the warm-up sketch below)
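A generic warm-up sketch showing the same pattern on a toy module (not tied to this project's classes; requires a working PyTorch 2.x setup):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
# Same monkey-patching pattern as above: replace forward with its compiled version.
model.forward = torch.compile(model.forward, mode="max-autotune", fullgraph=True)

x = torch.randn(1, 3, 224, 224)
_ = model(x)   # first call: slow, triggers compilation/autotuning
_ = model(x)   # later calls: reuse the compiled kernels
```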
4. Memory Attention Mechanism
Information from past frames is used for cross-frame modeling:
```python
def _prepare_memory_conditioned_features(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    output_dict,
    num_frames,
    track_in_reverse=False,  # tracking in reverse time order (for demo usage)
):
    """Fuse the current frame's visual feature map with previous memory."""
    B = current_vision_feats[-1].size(1)  # batch size on this frame
    C = self.hidden_dim
    H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
    device = current_vision_feats[-1].device
    # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
    # In this case, we skip the fusion with any memory.
    if self.num_maskmem == 0:  # Disable memory and skip fusion
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        return pix_feat
    num_obj_ptr_tokens = 0
    tpos_sign_mul = -1 if track_in_reverse else 1
    # Step 1: condition the visual features of the current frame on previous memories
    if not is_init_cond_frame:
        # Retrieve the memories encoded with the maskmem backbone
        to_cat_memory, to_cat_memory_pos_embed = [], []
        # Add conditioning frames's output first (all cond frames have t_pos=0 for
        # when getting temporal positional embedding below)
        assert len(output_dict["cond_frame_outputs"]) > 0
        # Select a maximum number of temporally closest cond frames for cross attention
        cond_outputs = output_dict["cond_frame_outputs"]
        selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
            frame_idx, cond_outputs, self.max_cond_frames_in_attn
        )
        t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
        # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
        # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
        # We also allow taking the memory frame non-consecutively (with stride>1), in which case
        # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
        stride = 1 if self.training else self.memory_temporal_stride_for_eval
        for t_pos in range(1, self.num_maskmem):
            t_rel = self.num_maskmem - t_pos  # how many frames before current frame
            if t_rel == 1:
                # for t_rel == 1, we take the last frame (regardless of r)
                if not track_in_reverse:
                    # the frame immediately before this frame (i.e. frame_idx - 1)
                    prev_frame_idx = frame_idx - t_rel
                else:
                    # the frame immediately after this frame (i.e. frame_idx + 1)
                    prev_frame_idx = frame_idx + t_rel
            else:
                # for t_rel >= 2, we take the memory frame from every r-th frames
                if not track_in_reverse:
                    # first find the nearest frame among every r-th frames before this frame
                    # for r=1, this would be (frame_idx - 2)
                    prev_frame_idx = ((frame_idx - 2) // stride) * stride
                    # then seek further among every r-th frames
                    prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
                else:
                    # first find the nearest frame among every r-th frames after this frame
                    # for r=1, this would be (frame_idx + 2)
                    prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
                    # then seek further among every r-th frames
                    prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
            out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
            if out is None:
                # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
                # frames, we still attend to it as if it's a non-conditioning frame.
                out = unselected_cond_outputs.get(prev_frame_idx, None)
            t_pos_and_prevs.append((t_pos, out))
        for t_pos, prev in t_pos_and_prevs:
            if prev is None:
                continue  # skip padding frames
            # "maskmem_features" might have been offloaded to CPU in demo use cases,
            # so we load it back to GPU (it's a no-op if it's already on GPU).
            feats = prev["maskmem_features"].to(device, non_blocking=True)
            to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
            # Spatial positional encoding (it might have been offloaded to CPU in eval)
            maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
            maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
            # Temporal positional encoding
            maskmem_enc = (
                maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
            )
            to_cat_memory_pos_embed.append(maskmem_enc)
        # Construct the list of past object pointers
        if self.use_obj_ptrs_in_encoder:
            max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
            # First add those object pointers from selected conditioning frames
            # (optionally, only include object pointers in the past during evaluation)
            if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
                ptr_cond_outputs = {
                    t: out
                    for t, out in selected_cond_outputs.items()
                    if (t >= frame_idx if track_in_reverse else t <= frame_idx)
                }
            else:
                ptr_cond_outputs = selected_cond_outputs
            pos_and_ptrs = [
                # Temporal pos encoding contains how far away each pointer is from current frame
                (
                    (
                        (frame_idx - t) * tpos_sign_mul
                        if self.use_signed_tpos_enc_to_obj_ptrs
                        else abs(frame_idx - t)
                    ),
                    out["obj_ptr"],
                )
                for t, out in ptr_cond_outputs.items()
            ]
            # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
            for t_diff in range(1, max_obj_ptrs_in_encoder):
                t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
                if t < 0 or (num_frames is not None and t >= num_frames):
                    break
                out = output_dict["non_cond_frame_outputs"].get(
                    t, unselected_cond_outputs.get(t, None)
                )
                if out is not None:
                    pos_and_ptrs.append((t_diff, out["obj_ptr"]))
            # If we have at least one object pointer, add them to the across attention
            if len(pos_and_ptrs) > 0:
                pos_list, ptrs_list = zip(*pos_and_ptrs)
                # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
                obj_ptrs = torch.stack(ptrs_list, dim=0)
                # a temporal positional embedding based on how far each object pointer is from
                # the current frame (sine embedding normalized by the max pointer num).
                if self.add_tpos_enc_to_obj_ptrs:
                    t_diff_max = max_obj_ptrs_in_encoder - 1
                    tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
                    obj_pos = torch.tensor(pos_list).to(
                        device=device, non_blocking=True
                    )
                    obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                    obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                    obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
                else:
                    obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
                if self.mem_dim < C:
                    # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
                    obj_ptrs = obj_ptrs.reshape(
                        -1, B, C // self.mem_dim, self.mem_dim
                    )
                    obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
                    obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
                to_cat_memory.append(obj_ptrs)
                to_cat_memory_pos_embed.append(obj_pos)
                num_obj_ptr_tokens = obj_ptrs.shape[0]
            else:
                num_obj_ptr_tokens = 0
    else:
        # for initial conditioning frames, encode them without using any previous memory
        if self.directly_add_no_mem_embed:
            # directly add no-mem embedding (instead of using the transformer encoder)
            pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
            pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
            return pix_feat_with_mem
        # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
        to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
        to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
    # Step 2: Concatenate the memories and forward through the transformer encoder
    memory = torch.cat(to_cat_memory, dim=0)
    memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
    pix_feat_with_mem = self.memory_attention(
        curr=current_vision_feats,
        curr_pos=current_vision_pos_embeds,
        memory=memory,
        memory_pos=memory_pos_embed,
        num_obj_ptr_tokens=num_obj_ptr_tokens,
    )
    # reshape the output (HW)BC => BCHW
    pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
    return pix_feat_with_mem
```
- Fuses memory features from past frames into the current frame's features
- Maintains temporal ordering through temporal positional encodings (a sketch of the sine embedding idea follows below)
- Exchanges information across frames through the attention mechanism
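The temporal encoding applied to the object pointers (get_1d_sine_pe above) is a standard 1D sine/cosine positional embedding over the frame distance. A self-contained illustration of that idea, not necessarily the project's exact implementation:

```python
import torch

def sine_pos_embedding(pos, dim, temperature=10000.0):
    """Standard 1D sine/cosine embedding of positions `pos` into `dim` channels (dim must be even)."""
    half = dim // 2
    freqs = temperature ** (torch.arange(half, dtype=torch.float32) / half)
    angles = pos.float().unsqueeze(-1) / freqs               # [num_pos, dim//2]
    return torch.cat([angles.sin(), angles.cos()], dim=-1)   # [num_pos, dim]

# e.g. distances of 3 object pointers from the current frame, normalized by the max distance
t_diff = torch.tensor([1, 2, 5])
pe = sine_pos_embedding(t_diff / 15.0, dim=64)
print(pe.shape)   # torch.Size([3, 64])
```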
5. Usage Example
```python
while True:
    ret, frame = cap.read()
    if not ret:
        break
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    width, height = frame.shape[:2][::-1]
    if not if_init:
        predictor.load_first_frame(frame)
        if_init = True
        ann_frame_idx = 0  # the frame index we interact with
        # First annotation
        ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)
        ##! add points, `1` means positive click and `0` means negative click
        points = np.array([[600, 255]], dtype=np.float32)
        labels = np.array([1], dtype=np.int32)
        _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
            frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels
        )
        # Second annotation
        ann_obj_id = 2
        ## ! add bbox
        bbox = np.array([[600, 214], [765, 286]], dtype=np.float32)
        _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
            frame_idx=ann_frame_idx, obj_id=ann_obj_id, bbox=bbox
        )
```
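The loop above only shows the initialization branch. A sketch of the matching else branch for the remaining frames, continuing the same loop and based on the track() interface shown in section 2.2 (the overlay and display details are illustrative):

```python
    else:
        # Subsequent frames: propagate the first-frame prompts forward in time
        out_obj_ids, out_mask_logits = predictor.track(frame)

        # Merge all object masks (logits > 0 means foreground) into a single overlay
        all_mask = np.zeros((height, width), dtype=np.uint8)
        for i in range(len(out_obj_ids)):
            obj_mask = (out_mask_logits[i] > 0.0).squeeze(0).cpu().numpy().astype(np.uint8) * 255
            all_mask = cv2.bitwise_or(all_mask, obj_mask)

        # Convert back to BGR for display and blend the mask on top of the frame
        vis = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        vis = cv2.addWeighted(vis, 1.0, cv2.cvtColor(all_mask, cv2.COLOR_GRAY2BGR), 0.5, 0)
        cv2.imshow("SAM2 real-time segmentation", vis)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
```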
Summary
Key points of the real-time inference implementation:
- Frame-by-frame processing: each frame is handled as it arrives, avoiding the latency that batching would introduce
- Memory mechanism: a fixed-size history of past frames is maintained for cross-frame information fusion
- Memory management: the number of stored frames is capped, and CPU offloading is supported
- Precision: bfloat16 is used to reduce compute and memory overhead
- Compilation: torch.compile can optionally be enabled for further speedup
- Feature reuse: already-computed features are cached and shared to avoid redundant computation
Together, these optimizations allow the project to reach real-time performance while preserving accuracy.