SAM2模型微调训练、验证和预测(Part2)

        在Part1中,我们详细介绍了SAM2模型的微调训练流程。

SAM2模型微调训练、验证和预测(Part1)_sam2微调-优快云博客

        本文将重点讲解模型验证和预测的实现方法。

        包含以下内容:

        1. 模型验证​​:在有掩码标签的情况下评估模型性能

        2. 零样本预测​​:对全新图像进行无监督预测

 

一、模型验证流程

        验证阶段使用带标注的测试集数据,通过计算IoU(交并比)和CPA(正确像素准确率)量化模型性能。以下是关键步骤解析:

1. 数据准备与加载

# 配置路径
data_dir = "dataset"
images_dir = os.path.join(data_dir, "images")
masks_dir = os.path.join(data_dir, "masks")
train_csv = os.path.join(data_dir, "train.csv")

model_cfg = r"configs\sam2.1\sam2.1_hiera_l.yaml"
checkpoint_path = r"checkpoints\sam2.1_hiera_large.pt"
finetuned_weights = r"weights\best.pt"  # 训练后的模型文件名
  • 数据集要求与训练集一致,包含images/masks/目录和train.csv索引文件
  • 读取sam2必要配置文件和模型训练保存的模型文件

2. 读取图像并调整分辨率

def read_image(image_path, mask_path):
    img = cv2.imread(image_path)[..., ::-1]
    mask = cv2.imread(mask_path, 0)
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1]*r), int(img.shape[0]*r)))
    mask = cv2.resize(mask, (int(mask.shape[1]*r), int(mask.shape[0]*r)), interpolation=cv2.INTER_NEAREST)
    return img, mask
  • 掩码必须使用INTER_NEAREST插值,避免插值产生无效类别值
  • 图像与掩码需同步缩放,确保空间对齐

3. 提示点生成

def get_points(mask, num_points=30):
    coords = np.argwhere(mask > 0)  # 获取所有前景像素坐标
    points = []
    for _ in range(num_points):
        yx = coords[np.random.randint(len(coords))]  # 随机选择前景点
        points.append([[yx[1], yx[0]]])  # 转换为(x,y)格式
    return np.array(points)
  • 真实掩码区域采样更符合实际应用场景,但是必须依赖掩码文件。
  • 复杂场景可以增加点数(如50-100),简单场景可减少(10-20)

4. 模型加载与推理

# 加载微调后的模型
sam2_model = build_sam2(model_cfg, checkpoint_path, device="cuda")
predictor = SAM2ImagePredictor(sam
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值