SAM Fine-Tuning / PEFT: A Roundup of Related Papers

SAM (Segment Anything, see the paper analysis) proposed a new task, model, and dataset and achieves strong general-purpose segmentation, but it still adapts poorly to data from some specialized domains, so fine-tuning is needed. This post collects work related to fine-tuning SAM.

  • Because SAM is a general-purpose segmentation model, fine-tuning usually has to preserve its strong general capability. As a result, almost all papers follow the same ideas as today's parameter-efficient fine-tuning of large models, mainly using LoRA, adapters, and similar techniques.
  • Some blog posts adapt to domain data by fine-tuning only the mask decoder (see the sketch after this list).
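As a concrete illustration of the second bullet, here is a minimal decoder-only fine-tuning sketch (an illustrative assumption, not any specific blog's recipe): freeze SAM's image encoder and prompt encoder so its general-purpose features stay intact, and let gradients flow only into the mask decoder. It assumes the official `segment_anything` package and a local `sam_vit_b.pth` checkpoint.

```python
import torch
from segment_anything import sam_model_registry

# Build SAM from the official registry (checkpoint path is an assumption).
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")

# Freeze both encoders to preserve SAM's general segmentation ability.
for p in sam.image_encoder.parameters():
    p.requires_grad = False
for p in sam.prompt_encoder.parameters():
    p.requires_grad = False

# Only the mask decoder's parameters are handed to the optimizer.
decoder_params = [p for p in sam.mask_decoder.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(decoder_params, lr=1e-4, weight_decay=0.01)
```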

SAM consists of three main parts: an image encoder, a prompt encoder, and a mask decoder, so fine-tuning work mostly revolves around these three components.
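To make the three-part structure concrete, the sketch below runs one prompted forward pass through the three modules. Shapes follow the official ViT-B SAM; the zero input tensor and the box coordinates are illustrative assumptions.

```python
import torch
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth").eval()
image = torch.zeros(1, 3, 1024, 1024)           # preprocessed input (assumption)
box = torch.tensor([[100., 100., 400., 400.]])  # (1, 4) box prompt in 1024-scale coords

with torch.no_grad():
    embeddings = sam.image_encoder(image)       # (1, 256, 64, 64)
    sparse, dense = sam.prompt_encoder(points=None, boxes=box, masks=None)
    low_res_logits, iou_pred = sam.mask_decoder(
        image_embeddings=embeddings,
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=False,                 # (1, 1, 256, 256) mask logits
    )
```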

peft-sam: Parameter Efficient Fine-Tuning of Segment Anything Models - implements several parameter-efficient fine-tuning (PEFT) methods for SAM in the biomedical-imaging domain

paper:https://arxiv.org/abs/2502.00418
code:https://github.com/computational-cell-analytics/peft-sam

Lightning Segment-Anything Model (2023) - code only; you need to read the code to see exactly how the fine-tuning is implemented

code:https://github.com/luca-medeiros/lightning-sam

  • Only supports bounding-box prompts

This library lets you fine-tune Meta AI's powerful Segment-Anything model on your own COCO-format dataset. Built on Lightning AI's Fabric framework, it provides an efficient, easy-to-use implementation for state-of-the-art instance segmentation.
This codebase is an experiment, a proof of concept exploring whether fine-tuning SAM with bounding boxes as prompts improves IoU overall or improves mask quality. Users can fine-tune SAM on a COCO-format dataset for a specific task where SAM underperforms (e.g., segmenting text on documents) and then use the model with interactive prompts, just like the original SAM.
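Since this approach fine-tunes with box prompts, a natural question is where the boxes come from during training. A common approach, sketched below as an assumption independent of lightning-sam's actual code, is to derive a box from each ground-truth mask and jitter it slightly to simulate imperfect user boxes:

```python
import numpy as np

def mask_to_box(mask: np.ndarray, jitter: int = 5) -> np.ndarray:
    """Return [x1, y1, x2, y2] for a non-empty binary HxW mask, with random jitter."""
    ys, xs = np.where(mask > 0)
    x1, x2 = xs.min(), xs.max()
    y1, y2 = ys.min(), ys.max()
    h, w = mask.shape
    # Expand each side by a few random pixels, clipped to the image bounds.
    x1 = max(0, x1 - np.random.randint(0, jitter + 1))
    y1 = max(0, y1 - np.random.randint(0, jitter + 1))
    x2 = min(w - 1, x2 + np.random.randint(0, jitter + 1))
    y2 = min(h - 1, y2 + np.random.randint(0, jitter + 1))
    return np.array([x1, y1, x2, y2], dtype=np.float32)
```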

SAM-Adapter: Adapting SAM in Underperformed Scenes: Camouflage, Shadow, Medical Image Segmentation, and More (2023) - fine-tunes via adapters; no prompt annotations required

  • But the project warns that GPU memory usage can be very high, so make sure you have the resources…

code:https://github.com/tianrun-chen/SAM-Adapter-PyTorch
SAM can fail or perform poorly on certain segmentation tasks, such as shadow detection and camouflaged (concealed) object detection. This work is the first to pave the way for applying the large pre-trained image segmentation model SAM to these downstream tasks, even where SAM underperforms. Instead of fine-tuning the SAM network itself, we propose SAM-Adapter, which injects domain-specific information or visual prompts into the segmentation network through simple yet effective adapters. By combining task-specific knowledge with the general knowledge learned by the large model, SAM-Adapter significantly boosts SAM's performance on challenging tasks, as shown by extensive experiments. We even outperform task-specific network models and achieve state-of-the-art performance on the tasks we tested (camouflaged object detection and shadow detection). We also tested polyp segmentation (medical image segmentation) and obtained better results. We believe our work opens up opportunities for exploiting SAM in downstream tasks, with potential applications across fields including medical image processing, agriculture, and remote sensing.
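For readers unfamiliar with adapters, the snippet below sketches a generic bottleneck adapter in the spirit of SAM-Adapter (a simplified assumption, not the paper's exact design): a small residual MLP inserted into the frozen backbone, with a zero-initialized output projection so training starts from SAM's original behavior.

```python
import torch.nn as nn

class Adapter(nn.Module):
    """Generic bottleneck adapter: only these parameters are trained."""
    def __init__(self, dim: int, bottleneck: int = 64):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)
        self.act = nn.GELU()
        self.up = nn.Linear(bottleneck, dim)
        nn.init.zeros_(self.up.weight)  # start as identity, preserving SAM's output
        nn.init.zeros_(self.up.bias)

    def forward(self, x):
        return x + self.up(self.act(self.down(x)))
```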

Medical SAM Adapter: Adapting Segment Anything Model for Medical Image Segmentation (2023) - fine-tunes on medical data via adapters; prompt annotations required

paper:https://arxiv.org/pdf/2304.12620
code:https://github.com/SuperMedIntel/Medical-SAM-Adapter (supports SAM / MobileSAM / EfficientSAM)

Lacking specific medical domain knowledge, SAM underperforms on medical image segmentation tasks. This raises the question of how to enhance SAM's ability to segment medical images. This paper proposes the Medical SAM Adapter (Med-SA), which incorporates domain-specific medical knowledge into the segmentation model through lightweight yet effective adaptation techniques. In Med-SA, we design a Space-Depth Transpose (SD-Trans) module to adapt 2D SAM to 3D medical images, and a Hyper-Prompting Adapter (HyP-Adpt) to achieve prompt-conditioned adaptation. Comprehensive evaluation on 17 medical image segmentation tasks spanning multiple imaging modalities shows that Med-SA outperforms several state-of-the-art (SOTA) medical image segmentation methods while updating only 2% of the model parameters.
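The "only 2% of the model parameters" figure is easy to verify for any PEFT setup with a small helper like the one below (a generic utility of our own, not Med-SA code):

```python
import torch.nn as nn

def trainable_fraction(model: nn.Module) -> float:
    """Fraction of parameters that will receive gradients."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable / total

# Example usage after freezing/adapter injection:
# print(f"{100 * trainable_fraction(sam):.2f}% of parameters are trainable")
```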

finetune-anything (2023) - supports a high degree of customization; you can choose which locations use adapters

code:https://github.com/ziqi-jin/finetune-anything
The Segment Anything Model (SAM) has revolutionized computer vision. Fine-tuning SAM can solve a large number of basic computer vision tasks. We are designing a class-aware, one-stage tool for training fine-tuned SAM-based models.
You supply the dataset for your task and the name of a supported task, and the tool helps you obtain a model fine-tuned for that task. You can also design your own extended SAM model; FA provides the training, testing, and deployment pipelines.

Customized Segment Anything Model for Medical Image Segmentation (2023) - fine-tunes on medical data via LoRA, performing prompt-free automatic semantic segmentation; no prompt annotations are needed, but SAM's interactivity is lost

paper:https://arxiv.org/pdf/2304.13785
code:https://github.com/hitachinsk/SAMed

We propose SAMed, a general solution for medical image segmentation. Unlike previous methods, SAMed is built on the large-scale image segmentation model Segment Anything Model (SAM), exploring a new research paradigm of customizing large-scale models for medical image segmentation. SAMed applies a low-rank-based (LoRA) fine-tuning strategy to the SAM image encoder and fine-tunes it together with the prompt encoder and mask decoder on labeled medical image segmentation datasets. We also observe that a warm-up fine-tuning strategy and the AdamW optimizer help SAMed converge successfully and lower the loss. Unlike SAM, SAMed can perform semantic segmentation on medical images. Our trained SAMed model achieves 81.88 DSC (Dice similarity coefficient) and 20.64 HD (Hausdorff distance) on the Synapse multi-organ segmentation dataset, on par with current state-of-the-art methods. We conduct extensive experiments to validate the effectiveness of our design. Since SAMed updates only a small fraction of SAM's parameters, its deployment and storage costs are quite low in practice.
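To illustrate the LoRA strategy SAMed applies to the image encoder, here is a minimal low-rank wrapper around a frozen linear layer (a simplified sketch, not SAMed's actual implementation, which targets the attention's q/v projections):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wraps a frozen nn.Linear with a trainable low-rank residual update."""
    def __init__(self, base: nn.Linear, r: int = 4, alpha: int = 4):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # pretrained weight stays frozen
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: no change at step 0
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

# Example (attribute names follow the official segment_anything ViT encoder):
# for blk in sam.image_encoder.blocks:
#     blk.attn.qkv = LoRALinear(blk.attn.qkv)
```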

[NeurIPS 2023] HQ-SAM: Segment Anything in High Quality - adds an extra token to refine the generated masks

ETH Zürich, The Hong Kong University of Science and Technology
paper:https://arxiv.org/abs/2306.01567
code:https://github.com/SysCV/sam-hq

Although SAM was trained with 1.1 billion masks, its mask prediction quality still falls short in many cases, especially for objects with intricate structures. We propose HQ-SAM, which equips SAM with the ability to accurately segment any object while preserving its original promptable design, efficiency, and zero-shot generalizability. Our design carefully reuses and preserves SAM's pretrained weights, introducing only minimal additional parameters and computation. We design a learnable High-Quality Output Token that is injected into SAM's mask decoder and is responsible for predicting the high-quality mask. Instead of applying it only on the mask-decoder features, we first fuse those features with early and final ViT features to improve mask details. To train the introduced learnable parameters, we compose a dataset of 44K fine-grained masks from several sources. HQ-SAM is trained only on this 44K-mask dataset, which takes just 4 hours on 8 GPUs. We validate HQ-SAM's effectiveness on a suite of 10 downstream-task segmentation datasets, 8 of which are evaluated under a zero-shot transfer protocol.
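The core idea, heavily simplified below (a conceptual sketch under our own naming, not HQ-SAM's code), is that one extra learnable output token is concatenated to SAM's decoder tokens; its corresponding output embedding predicts the high-quality mask while all pretrained weights stay frozen:

```python
import torch
import torch.nn as nn

class HQOutputToken(nn.Module):
    """Hypothetical module holding one extra learnable output token."""
    def __init__(self, dim: int = 256):
        super().__init__()
        self.hq_token = nn.Parameter(torch.zeros(1, 1, dim))  # the only new parameter here

    def extend(self, decoder_tokens: torch.Tensor) -> torch.Tensor:
        # decoder_tokens: (B, N, dim) = SAM's IoU token + mask tokens (+ prompt tokens);
        # the HQ token rides along through the frozen two-way transformer.
        b = decoder_tokens.shape[0]
        return torch.cat([decoder_tokens, self.hq_token.expand(b, -1, -1)], dim=1)
```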

[ICME 2024] PA-SAM: Prompt Adapter SAM for High-quality Image Segmentation - fine-tunes via an adapter

paper:https://arxiv.org/abs/2401.13051
code:https://github.com/xzz2/pa-sam

We introduce a novel prompt-driven adapter into SAM, namely the Prompt Adapter Segment Anything Model (PA-SAM), aiming to improve the quality of the original SAM's segmentation masks. By training only the prompt adapter, PA-SAM extracts detailed information from images and refines the mask-decoder features at both the sparse and dense prompt levels, improving SAM's segmentation performance to produce high-quality masks. Experimental results show that PA-SAM outperforms other SAM-based methods on high-quality, zero-shot, and open-set segmentation.

References:
https://zhuanlan.zhihu.com/p/622677489
https://encord.com/blog/learn-how-to-fine-tune-the-segment-anything-model-sam/
https://zhuanlan.zhihu.com/p/627098441
