基本上按照官方的指示进行就可以了,记录其中出现的错误。
0. 如果有已经成功的 conda 环境,先创建 conda 环境
conda 镜像源设置脚本
#!/usr/bin/env bash
cat > ~/.condarc <<'EOF'
channels:
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
- defaults
show_channel_urls: true
channel_priority: flexible
EOF
# 接受 main 频道条款
conda tos accept --override-channels \
--channel https://repo.anaconda.com/pkgs/main
# 若也用到 r 包,再接受 r 频道
conda tos accept --override-channels \
--channel https://repo.anaconda.com/pkgs/r
使用文件创建 sam2 的运行环境。
conda env create -f environment.SAM2.yaml -n sam2
conda activate sam2
需要注意的是,尽管有了 environment.yaml 文件, conda 环境也有可能因为内部深层次的版本冲突导致安装失败。所以这并不是一个万能的方法。这个时候需要按照步骤来,首先是创建 conda 环境:
conda create -n sam2 python==3.10 -y
1. 拉取代码然后安装环境
git clone https://github.com/facebookresearch/sam2.git && cd sam2
pip install -e .
2. 下载模型的权重文件
cd checkpoints && \
./download_ckpts.sh && \
cd ..
3. 在 sam2 的根目录下面创建 run.py 测试模型的功能
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
# 加载图像
img = Image.open("./assets/sa_v_dataset.jpg").convert("RGB")
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
# 设置图像
predictor.set_image(img)
# 获取图像尺寸
width, height = img.size
# 在图像中心创建一个点提示
point_coords = np.array([[width//2, height//2]])
point_labels = np.array([1])
masks, scores, logits = predictor.predict(
point_coords=point_coords,
point_labels=point_labels
)
# 可视化结果
print(f"生成了 {len(masks)} 个掩码")
print(f"置信度分数: {scores}")
# 显示原始图像和分割结果
plt.figure(figsize=(15, 5))
# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(img)
plt.title("Original Image")
plt.axis('off')
# 最佳掩码
best_mask_idx = np.argmax(scores)
best_mask = masks[best_mask_idx]
plt.subplot(1, 3, 2)
plt.imshow(best_mask, cmap='gray')
plt.title(f"Best Mask (Score: {scores[best_mask_idx]:.3f})")
plt.axis('off')
# 叠加显示
plt.subplot(1, 3, 3)
plt.imshow(img)
plt.imshow(best_mask, alpha=0.5, cmap='viridis')
plt.title("Overlay")
plt.axis('off')
plt.tight_layout()
plt.show()
# 保存结果
result_img = Image.fromarray((best_mask * 255).astype(np.uint8))
result_img.save("segmentation_result.png")
print("结果已保存为 segmentation_result.png")
4. 在 sam2 的根目录下面创建 run_v.py 测试模型处理视频的功能
import torch
from sam2.build_sam import build_sam2_video_predictor
import numpy as np
import os
def check_dependencies():
"""检查必要的依赖是否安装"""
try:
import decord
print("✓ decord 已安装")
except ImportError:
print("✗ 错误: 未安装 decord 库")
print("请运行: pip install decord")
return False
return True
def simple_video_segmentation(video_path):
"""
简化的视频分割函数 - 修正版本
"""
# 检查依赖
if not check_dependencies():
return
# 检查视频文件是否存在
if not os.path.exists(video_path):
print(f"✗ 错误: 视频文件不存在: {video_path}")
return
checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
try:
predictor = build_sam2_video_predictor(model_cfg, checkpoint)
print("✓ SAM2视频预测器加载成功")
except Exception as e:
print(f"✗ 模型加载失败: {e}")
return
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
try:
# 初始化视频状态
print("正在初始化视频...")
state = predictor.init_state(video_path)
print(f"✓ 视频初始化完成")
# state 是一个字典,我们需要查看它的结构
print(f"State 类型: {type(state)}")
print(f"State 键: {state.keys() if isinstance(state, dict) else 'N/A'}")
# 获取视频信息
if isinstance(state, dict):
num_frames = state.get('num_frames', 0)
print(f"总帧数: {num_frames}")
# 获取第一帧(如果有的话)
if 'frames' in state and len(state['frames']) > 0:
first_frame = state['frames'][0]
height, width = first_frame.shape[:2]
print(f"视频尺寸: {width}x{height}")
else:
# 如果没有frames信息,使用默认尺寸或从视频文件获取
import cv2
cap = cv2.VideoCapture(video_path)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
print(f"视频尺寸: {width}x{height}")
else:
print("State 不是字典类型")
return
# 在第一帧中心添加点提示
prompt_coords = np.array([[width//2, height//2]])
prompt_labels = np.array([1])
print("添加分割提示...")
# 添加提示 - 注意:这里可能需要调整参数
try:
frame_idx, object_ids, masks = predictor.add_new_points_or_box(
state,
point_coords=prompt_coords,
point_labels=prompt_labels
)
print(f"✓ 生成 {len(object_ids)} 个分割对象")
except Exception as e:
print(f"✗ 添加提示失败: {e}")
# 尝试其他参数组合
try:
frame_idx, object_ids, masks = predictor.add_new_points_or_box(
state,
points=prompt_coords,
labels=prompt_labels
)
print(f"✓ 生成 {len(object_ids)} 个分割对象")
except Exception as e2:
print(f"✗ 所有提示方式都失败: {e2}")
return
print(f"开始传播分割...")
# 传播到整个视频
processed_frames = 0
try:
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
processed_frames += 1
if processed_frames % 10 == 0: # 每10帧打印一次进度
print(f"已处理 {processed_frames} 帧")
print(f"✓ 视频分割完成! 总共处理 {processed_frames} 帧")
except Exception as e:
print(f"✗ 传播过程中出错: {e}")
import traceback
traceback.print_exc()
except Exception as e:
print(f"✗ 处理过程中出错: {e}")
import traceback
traceback.print_exc()
# 调试版本:查看state结构
def debug_video_segmentation(video_path):
"""
调试版本,查看state的具体结构
"""
checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, checkpoint)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
state = predictor.init_state(video_path)
print("=== State 结构分析 ===")
print(f"State 类型: {type(state)}")
if isinstance(state, dict):
for key, value in state.items():
print(f"{key}: {type(value)}")
if hasattr(value, 'shape'):
print(f" shape: {value.shape}")
elif isinstance(value, (list, tuple)) and len(value) > 0:
print(f" length: {len(value)}")
if hasattr(value[0], 'shape'):
print(f" first element shape: {value[0].shape}")
elif isinstance(value, (int, float, str, bool)):
print(f" value: {value}")
else:
print(f" value type: {type(value)}")
else:
print("State 的属性:")
for attr in dir(state):
if not attr.startswith('_'):
try:
value = getattr(state, attr)
print(f"{attr}: {type(value)}")
except:
print(f"{attr}: <无法获取>")
return state
# 使用
if __name__ == "__main__":
video_path = "./demo/data/gallery/01_dog.mp4"
# 先运行调试版本查看state结构
print("=== 调试状态 ===")
state = debug_video_segmentation(video_path)
print("\n=== 开始分割 ===")
simple_video_segmentation(video_path)
5. 问题记录
1. pip install -e . 的执行过程中(这其实是整个过程中最困难的部分了)
按照官网的指示,在 sam2 中的根目录下面运行 pip install -e . 的时候出现如下问题 安装的时候卡在 Downloading torch-2.9.1-cp310-cp310-manylinux_2_28_x86_64.whl (899.8 MB) 怎么办。一个好的解决方案就是手动安装卡住的这个包,并且在手动安装的时候指定镜像源:pip install torch torchvision torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple 如此完成单独依赖的安装之后,接下只需要再次执行 pip install -e . --verbose 即可。
2. 脚本运行时候额外的依赖安装
例如,分割图像的时候额外安装了 matplotlib 这个包;而在处理视频的脚本中可能还需要安装 opencv 这些。
1199

被折叠的 条评论
为什么被折叠?



