在Part1中,我们详细介绍了SAM2模型的微调训练流程。
SAM2模型微调训练、验证和预测(Part1)_sam2微调-优快云博客
本文将重点讲解模型验证和预测的实现方法。
包含以下内容:
1. 模型验证:在有掩码标签的情况下评估模型性能
2. 零样本预测:对全新图像进行无监督预测
一、模型验证流程
验证阶段使用带标注的测试集数据,通过计算IoU(交并比)和CPA(正确像素准确率)量化模型性能。以下是关键步骤解析:
1. 数据准备与加载
# 配置路径
data_dir = "dataset"
images_dir = os.path.join(data_dir, "images")
masks_dir = os.path.join(data_dir, "masks")
train_csv = os.path.join(data_dir, "train.csv")
model_cfg = r"configs\sam2.1\sam2.1_hiera_l.yaml"
checkpoint_path = r"checkpoints\sam2.1_hiera_large.pt"
finetuned_weights = r"weights\best.pt" # 训练后的模型文件名
- 数据集要求与训练集一致,包含
images/
、masks/
目录和train.csv
索引文件 - 读取sam2必要配置文件和模型训练保存的模型文件
2. 读取图像并调整分辨率
def read_image(image_path, mask_path):
img = cv2.imread(image_path)[..., ::-1]
mask = cv2.imread(mask_path, 0)
r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
img = cv2.resize(img, (int(img.shape[1]*r), int(img.shape[0]*r)))
mask = cv2.resize(mask, (int(mask.shape[1]*r), int(mask.shape[0]*r)), interpolation=cv2.INTER_NEAREST)
return img, mask
- 掩码必须使用
INTER_NEAREST
插值,避免插值产生无效类别值 - 图像与掩码需同步缩放,确保空间对齐
3. 提示点生成
def get_points(mask, num_points=30):
coords = np.argwhere(mask > 0) # 获取所有前景像素坐标
points = []
for _ in range(num_points):
yx = coords[np.random.randint(len(coords))] # 随机选择前景点
points.append([[yx[1], yx[0]]]) # 转换为(x,y)格式
return np.array(points)
- 真实掩码区域采样更符合实际应用场景,但是必须依赖掩码文件。
- 复杂场景可以增加点数(如50-100),简单场景可减少(10-20)
4. 模型加载与推理
# 加载微调后的模型
sam2_model = build_sam2(model_cfg, checkpoint_path, device="cuda")
predictor = SAM2ImagePredictor(sam