Sam2复现过程

基本上按照官方的指示进行就可以了,记录其中出现的错误。

0. 如果有已经成功的 conda 环境,先创建 conda 环境

conda 镜像源设置脚本

#!/usr/bin/env bash
cat > ~/.condarc <<'EOF'
channels:
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
  - defaults
show_channel_urls: true
channel_priority: flexible
EOF

# 接受 main 频道条款
conda tos accept --override-channels \
  --channel https://repo.anaconda.com/pkgs/main

# 若也用到 r 包,再接受 r 频道
conda tos accept --override-channels \
  --channel https://repo.anaconda.com/pkgs/r

使用文件创建 sam2 的运行环境。

conda env create -f environment.SAM2.yaml -n sam2
conda activate sam2

需要注意的是,尽管有了 environment.yaml 文件, conda 环境也有可能因为内部深层次的版本冲突导致安装失败。所以这并不是一个万能的方法。这个时候需要按照步骤来,首先是创建 conda 环境:conda create -n sam2 python==3.10 -y

1. 拉取代码然后安装环境

git clone https://github.com/facebookresearch/sam2.git && cd sam2

pip install -e .

2. 下载模型的权重文件

cd checkpoints && \
./download_ckpts.sh && \
cd ..

3. 在 sam2 的根目录下面创建 run.py 测试模型的功能

import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))

# 加载图像
img = Image.open("./assets/sa_v_dataset.jpg").convert("RGB")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # 设置图像
    predictor.set_image(img)
    
    # 获取图像尺寸
    width, height = img.size
    
    # 在图像中心创建一个点提示
    point_coords = np.array([[width//2, height//2]])
    point_labels = np.array([1])
    
    masks, scores, logits = predictor.predict(
        point_coords=point_coords, 
        point_labels=point_labels
    )
    
    # 可视化结果
    print(f"生成了 {len(masks)} 个掩码")
    print(f"置信度分数: {scores}")
    
    # 显示原始图像和分割结果
    plt.figure(figsize=(15, 5))
    
    # 原始图像
    plt.subplot(1, 3, 1)
    plt.imshow(img)
    plt.title("Original Image")
    plt.axis('off')
    
    # 最佳掩码
    best_mask_idx = np.argmax(scores)
    best_mask = masks[best_mask_idx]
    plt.subplot(1, 3, 2)
    plt.imshow(best_mask, cmap='gray')
    plt.title(f"Best Mask (Score: {scores[best_mask_idx]:.3f})")
    plt.axis('off')
    
    # 叠加显示
    plt.subplot(1, 3, 3)
    plt.imshow(img)
    plt.imshow(best_mask, alpha=0.5, cmap='viridis')
    plt.title("Overlay")
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # 保存结果
    result_img = Image.fromarray((best_mask * 255).astype(np.uint8))
    result_img.save("segmentation_result.png")
    print("结果已保存为 segmentation_result.png")

4. 在 sam2 的根目录下面创建 run_v.py 测试模型处理视频的功能

import torch
from sam2.build_sam import build_sam2_video_predictor
import numpy as np
import os

def check_dependencies():
    """检查必要的依赖是否安装"""
    try:
        import decord
        print("✓ decord 已安装")
    except ImportError:
        print("✗ 错误: 未安装 decord 库")
        print("请运行: pip install decord")
        return False
    return True

def simple_video_segmentation(video_path):
    """
    简化的视频分割函数 - 修正版本
    """
    # 检查依赖
    if not check_dependencies():
        return
    
    # 检查视频文件是否存在
    if not os.path.exists(video_path):
        print(f"✗ 错误: 视频文件不存在: {video_path}")
        return
    
    checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
    
    try:
        predictor = build_sam2_video_predictor(model_cfg, checkpoint)
        print("✓ SAM2视频预测器加载成功")
    except Exception as e:
        print(f"✗ 模型加载失败: {e}")
        return

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        try:
            # 初始化视频状态
            print("正在初始化视频...")
            state = predictor.init_state(video_path)
            print(f"✓ 视频初始化完成")
            
            # state 是一个字典,我们需要查看它的结构
            print(f"State 类型: {type(state)}")
            print(f"State 键: {state.keys() if isinstance(state, dict) else 'N/A'}")
            
            # 获取视频信息
            if isinstance(state, dict):
                num_frames = state.get('num_frames', 0)
                print(f"总帧数: {num_frames}")
                
                # 获取第一帧(如果有的话)
                if 'frames' in state and len(state['frames']) > 0:
                    first_frame = state['frames'][0]
                    height, width = first_frame.shape[:2]
                    print(f"视频尺寸: {width}x{height}")
                else:
                    # 如果没有frames信息,使用默认尺寸或从视频文件获取
                    import cv2
                    cap = cv2.VideoCapture(video_path)
                    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    cap.release()
                    print(f"视频尺寸: {width}x{height}")
            else:
                print("State 不是字典类型")
                return
            
            # 在第一帧中心添加点提示
            prompt_coords = np.array([[width//2, height//2]])
            prompt_labels = np.array([1])
            
            print("添加分割提示...")
            # 添加提示 - 注意:这里可能需要调整参数
            try:
                frame_idx, object_ids, masks = predictor.add_new_points_or_box(
                    state, 
                    point_coords=prompt_coords,
                    point_labels=prompt_labels
                )
                print(f"✓ 生成 {len(object_ids)} 个分割对象")
            except Exception as e:
                print(f"✗ 添加提示失败: {e}")
                # 尝试其他参数组合
                try:
                    frame_idx, object_ids, masks = predictor.add_new_points_or_box(
                        state, 
                        points=prompt_coords,
                        labels=prompt_labels
                    )
                    print(f"✓ 生成 {len(object_ids)} 个分割对象")
                except Exception as e2:
                    print(f"✗ 所有提示方式都失败: {e2}")
                    return
            
            print(f"开始传播分割...")
            
            # 传播到整个视频
            processed_frames = 0
            try:
                for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
                    processed_frames += 1
                    if processed_frames % 10 == 0:  # 每10帧打印一次进度
                        print(f"已处理 {processed_frames} 帧")
                
                print(f"✓ 视频分割完成! 总共处理 {processed_frames} 帧")
                
            except Exception as e:
                print(f"✗ 传播过程中出错: {e}")
                import traceback
                traceback.print_exc()
            
        except Exception as e:
            print(f"✗ 处理过程中出错: {e}")
            import traceback
            traceback.print_exc()

# 调试版本:查看state结构
def debug_video_segmentation(video_path):
    """
    调试版本,查看state的具体结构
    """
    checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
    
    predictor = build_sam2_video_predictor(model_cfg, checkpoint)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        state = predictor.init_state(video_path)
        
        print("=== State 结构分析 ===")
        print(f"State 类型: {type(state)}")
        
        if isinstance(state, dict):
            for key, value in state.items():
                print(f"{key}: {type(value)}")
                if hasattr(value, 'shape'):
                    print(f"  shape: {value.shape}")
                elif isinstance(value, (list, tuple)) and len(value) > 0:
                    print(f"  length: {len(value)}")
                    if hasattr(value[0], 'shape'):
                        print(f"  first element shape: {value[0].shape}")
                elif isinstance(value, (int, float, str, bool)):
                    print(f"  value: {value}")
                else:
                    print(f"  value type: {type(value)}")
        else:
            print("State 的属性:")
            for attr in dir(state):
                if not attr.startswith('_'):
                    try:
                        value = getattr(state, attr)
                        print(f"{attr}: {type(value)}")
                    except:
                        print(f"{attr}: <无法获取>")
        
        return state

# 使用
if __name__ == "__main__":
    video_path = "./demo/data/gallery/01_dog.mp4"
    
    # 先运行调试版本查看state结构
    print("=== 调试状态 ===")
    state = debug_video_segmentation(video_path)
    
    print("\n=== 开始分割 ===")
    simple_video_segmentation(video_path)

5. 问题记录

1. pip install -e . 的执行过程中(这其实是整个过程中最困难的部分了)

按照官网的指示,在 sam2 中的根目录下面运行 pip install -e . 的时候出现如下问题 安装的时候卡在 Downloading torch-2.9.1-cp310-cp310-manylinux_2_28_x86_64.whl (899.8 MB) 怎么办。一个好的解决方案就是手动安装卡住的这个包,并且在手动安装的时候指定镜像源:pip install torch torchvision torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple 如此完成单独依赖的安装之后,接下只需要再次执行 pip install -e . --verbose 即可。

2. 脚本运行时候额外的依赖安装

例如,分割图像的时候额外安装了 matplotlib 这个包;而在处理视频的脚本中可能还需要安装 opencv 这些。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值