# -*- coding: utf-8 -*-
"""
Batch MedSAM/SAM box-prompt evaluation from CSV.
Each row provides: image_path,mask_path,x1,y1,x2,y2[,Dice,IoU,Pixel_Acc]
We compute Dice/IoU/PixelAcc and append them as the last 3 columns.
Usage:
python medsam_box_eval_from_csv.py \
--csv_in /path/to/list.csv \
--csv_out /path/to/list_with_metrics.csv \
--checkpoint /path/to/medsam_vit_b.pth \
--model_type vit_b \
--device cuda:0 \
--save_pred_dir /path/to/preds # optional
Notes:
- Checkpoint must be in SAM format (e.g., medsam_vit_b.pth).
- Header in CSV is optional. If present, include at least:
image_path,mask_path,x1,y1,x2,y2
"""
import argparse
import csv
import os
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from skimage import io
from skimage import transform
from skimage.color import gray2rgb
from segment_anything import sam_model_registry, SamPredictor
def to_uint8_rgb(img: np.ndarray) -> np.ndarray:
    """Convert an image of arbitrary channel layout / dtype to ``(H, W, 3)`` uint8.

    Channel handling:
      * ``(H, W)`` and ``(H, W, 1)`` grayscale are replicated to three channels;
      * ``(H, W, 2)`` (grayscale + alpha) keeps the gray channel, replicated;
      * ``(H, W, C)`` with ``C > 3`` keeps the first three channels (drops alpha).

    Non-uint8 inputs are min-max rescaled to ``[0, 255]``; a constant image maps
    to all zeros. uint8 inputs keep their values unchanged.
    """
    if img.ndim == 3 and img.shape[2] in (1, 2):
        # Grayscale stored with an explicit channel axis (optionally + alpha).
        img = img[..., 0]
    if img.ndim == 2:
        img = np.repeat(img[..., None], 3, axis=-1)
    elif img.ndim == 3 and img.shape[2] > 3:
        img = img[..., :3]
    if img.dtype != np.uint8:
        img = img.astype(np.float32)
        lo, hi = float(img.min()), float(img.max())
        if hi > lo:
            img = (255.0 * (img - lo) / (hi - lo)).clip(0, 255).astype(np.uint8)
        else:
            img = np.zeros_like(img, dtype=np.uint8)
    return img
def binarize_mask(mask: np.ndarray) -> np.ndarray:
    """Collapse a mask to a single-channel ``{0, 1}`` uint8 array.

    For a 3-D mask with several channels, a pixel is foreground when any channel
    is positive; a trailing singleton channel axis is simply dropped. Any other
    array is thresholded at ``> 0`` directly.
    """
    if mask.ndim != 3:
        return (mask > 0).astype(np.uint8)
    n_channels = mask.shape[2]
    flat = (mask > 0).any(axis=2) if n_channels > 1 else mask[..., 0]
    return (flat > 0).astype(np.uint8)
def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[float, float, float]:
    """Return ``(dice, iou, pixel_accuracy)`` for two masks of identical shape.

    Any value ``> 0`` counts as foreground. When both masks are empty there is
    nothing to disagree on, so Dice and IoU are reported as 1.0 (not 0).

    Raises:
        ValueError: if the two masks do not have the same shape.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if y_true.shape != y_pred.shape:
        raise ValueError(f"Shape mismatch: gt {y_true.shape} vs pred {y_pred.shape}")
    gt = y_true > 0
    pr = y_pred > 0
    tp = int(np.count_nonzero(gt & pr))
    fp = int(np.count_nonzero(~gt & pr))
    fn = int(np.count_nonzero(gt & ~pr))
    total = int(gt.size)
    tn = total - tp - fp - fn
    disagreement_or_overlap = tp + fp + fn
    dice = (2.0 * tp) / (2.0 * tp + fp + fn) if disagreement_or_overlap else 1.0
    iou = tp / disagreement_or_overlap if disagreement_or_overlap else 1.0
    pixel_acc = (tp + tn) / total if total else 1.0
    return float(dice), float(iou), float(pixel_acc)
def clip_box_to_image(box: np.ndarray, width: int, height: int) -> np.ndarray:
    """Clamp an ``[x0, y0, x1, y1]`` box into the image and order its corners.

    Every x coordinate is clamped to ``[0, width - 1]`` and every y coordinate to
    ``[0, height - 1]``; reversed corners are swapped so that ``x0 <= x1`` and
    ``y0 <= y1`` on return. The result is a float32 array.
    """
    left, top, right, bottom = map(float, box)
    x_limit = width - 1
    y_limit = height - 1
    left, right = (max(0, min(v, x_limit)) for v in (left, right))
    top, bottom = (max(0, min(v, y_limit)) for v in (top, bottom))
    left, right = min(left, right), max(left, right)
    top, bottom = min(top, bottom), max(top, bottom)
    return np.array([left, top, right, bottom], dtype=np.float32)
def has_header(first_row: List[str]) -> bool:
    """Return True if ``first_row`` names every required column.

    Matching ignores surrounding whitespace and letter case; extra columns are
    allowed.
    """
    present = {cell.strip().lower() for cell in first_row}
    return present.issuperset(("image_path", "mask_path", "x1", "y1", "x2", "y2"))
# Columns every input row must provide, and the metric columns this script
# (re)writes at the end of every output row.
_REQUIRED_COLS = ("image_path", "mask_path", "x1", "y1", "x2", "y2")
_METRIC_COLS = ("Dice", "IoU", "Pixel_Acc")


def _strip_module_prefix(state_dict: dict) -> dict:
    """Drop a leading ``module.`` (added when saving a DataParallel/DDP model) from all keys."""
    prefix = "module."
    if state_dict and all(k.startswith(prefix) for k in state_dict):
        return {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict


def _load_finetuned_weights(sam, ckpt_path: str) -> None:
    """Load a fine-tuned MedSAM state dict into ``sam`` in place.

    The training wrapper is built from ``sam.image_encoder``, ``sam.mask_decoder``
    and ``sam.prompt_encoder`` under attributes of the same names, so its keys are
    expected to line up with ``sam.state_dict()``. This is an assumption: missing and
    unexpected keys are printed so that any mismatch (e.g. extra adapter weights) is
    visible. Loading directly into ``sam`` avoids importing the training script.

    Raises:
        RuntimeError: if not a single checkpoint key matches the model.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Training checkpoints are often {"model": ..., "optimizer": ..., "epoch": ...}.
    if isinstance(ckpt, dict) and "model" in ckpt:
        ckpt = ckpt["model"]
    state = _strip_module_prefix(dict(ckpt))
    result = sam.load_state_dict(state, strict=False)
    matched = len(state) - len(result.unexpected_keys)
    if matched == 0:
        raise RuntimeError(f"No keys of {ckpt_path} match the SAM model; nothing was loaded.")
    if result.missing_keys:
        print(f"[WARN] {len(result.missing_keys)} model keys missing from checkpoint, "
              f"e.g. {result.missing_keys[:5]}")
    if result.unexpected_keys:
        print(f"[WARN] {len(result.unexpected_keys)} checkpoint keys were ignored, "
              f"e.g. {result.unexpected_keys[:5]}")
    print(f"Loaded {matched} tensors from fine-tuned checkpoint: {ckpt_path}")


@torch.no_grad()
def _predict_medsam(sam, img: np.ndarray, box: np.ndarray, device: str) -> np.ndarray:
    """Predict an (H, W) {0,1} uint8 mask for one image and one box, MedSAM style.

    The image is resized to 1024x1024 (bicubic, anti-aliased) and min-max
    normalised to [0, 1], and the box is rescaled to that grid. This feeds the
    image encoder directly, bypassing SAM's pixel-mean/std normalisation and
    padding. (Assumption: this matches how the training .npy inputs were made;
    confirm against your preprocessing script.)
    """
    H, W = img.shape[:2]
    img_1024 = transform.resize(
        img, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.float32)
    lo = float(img_1024.min())
    span = max(float(img_1024.max()) - lo, 1e-8)
    img_1024 = (img_1024 - lo) / span
    img_tensor = torch.from_numpy(img_1024).float().permute(2, 0, 1)[None].to(device)  # (1,3,1024,1024)

    box_1024 = box / np.array([W, H, W, H], dtype=np.float32) * 1024.0
    box_tensor = torch.as_tensor(box_1024, dtype=torch.float32, device=device)[None, None, :]  # (1,1,4)

    embedding = sam.image_encoder(img_tensor)
    sparse, dense = sam.prompt_encoder(points=None, boxes=box_tensor, masks=None)
    low_res_logits, _ = sam.mask_decoder(
        image_embeddings=embedding,
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=False,
    )
    prob = torch.sigmoid(low_res_logits)
    prob = F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)
    return (prob[0, 0] > 0.5).to(torch.uint8).cpu().numpy()


def _predict_sam(predictor, img: np.ndarray, box: np.ndarray, multimask_output: bool) -> np.ndarray:
    """Predict an (H, W) {0,1} uint8 mask with the stock ``SamPredictor`` pipeline.

    With ``multimask_output`` the candidate with the highest predicted score is kept.
    """
    predictor.set_image(img)
    masks, scores, _ = predictor.predict(box=box, multimask_output=multimask_output)
    if masks.ndim == 3:
        best = int(np.argmax(scores)) if multimask_output else 0
        pred = masks[best]
    else:
        pred = masks
    return (pred > 0).astype(np.uint8)


def _read_csv_rows(csv_path: str):
    """Read ``csv_path`` and locate the required columns.

    Returns:
        ``(header, data_rows, col_idx, keep_idx)``. ``col_idx`` maps each required
        column name to its position. ``keep_idx`` lists the input columns copied to
        the output, which is every column except pre-existing metric columns. Without
        a header, the documented positional layout (first six columns) is used.

    Raises:
        RuntimeError: if the CSV is empty.
    """
    # utf-8-sig strips a BOM (e.g. from Excel) so the header is still recognised.
    with open(csv_path, "r", newline="", encoding="utf-8-sig") as f:
        rows = list(csv.reader(f))
    if not rows:
        raise RuntimeError("Empty CSV.")
    if has_header(rows[0]):
        header = [c.strip() for c in rows[0]]
        lowered = [c.lower() for c in header]
        col_idx = {name: lowered.index(name) for name in _REQUIRED_COLS}
        metric_names = {m.lower() for m in _METRIC_COLS}
        keep_idx = [i for i, name in enumerate(lowered) if name not in metric_names]
        return header, rows[1:], col_idx, keep_idx
    col_idx = {name: i for i, name in enumerate(_REQUIRED_COLS)}
    return list(_REQUIRED_COLS), rows, col_idx, list(range(len(_REQUIRED_COLS)))


def main():
    """Predict a mask for every CSV row from its box prompt and append Dice/IoU/Pixel_Acc.

    Rows whose image/mask is missing or whose sizes disagree get ``nan`` metrics.
    Rows with too few columns or a non-numeric box are reported and skipped.
    """
    parser = argparse.ArgumentParser("Batch MedSAM/SAM evaluation from CSV")
    parser.add_argument("--csv_in", type=str, required=True)
    parser.add_argument("--csv_out", type=str, required=True)
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="SAM-format base checkpoint used to build the model (e.g. medsam_vit_b.pth)")
    parser.add_argument("--model_type", type=str, default="vit_b", choices=["vit_b", "vit_l", "vit_h"])
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--save_pred_dir", type=str, default=None)
    parser.add_argument("--multimask_output", action="store_true",
                        help="Only used with --preprocess sam")
    parser.add_argument("--finetuned_ckpt", type=str, default=None,
                        help="Optional fine-tuned MedSAM weights (e.g. medsam_model_best.pth) "
                             "loaded on top of --checkpoint")
    parser.add_argument("--preprocess", type=str, default="medsam", choices=["medsam", "sam"],
                        help="'medsam': resize to 1024x1024 + min-max to [0,1] (MedSAM training input); "
                             "'sam': original SamPredictor pipeline (for plain SAM checkpoints)")
    args = parser.parse_args()

    if args.save_pred_dir:
        os.makedirs(args.save_pred_dir, exist_ok=True)

    # Build the backbone from the base weights, then overlay the fine-tuned ones.
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    if args.finetuned_ckpt:
        _load_finetuned_weights(sam, args.finetuned_ckpt)
    sam.to(device=args.device)
    sam.eval()
    predictor = SamPredictor(sam) if args.preprocess == "sam" else None
    if predictor is None and args.multimask_output:
        print("[WARN] --multimask_output is ignored with --preprocess medsam (single-mask output).")

    header, data_rows, col_idx, keep_idx = _read_csv_rows(args.csv_in)
    out_rows = [[header[i] for i in keep_idx] + list(_METRIC_COLS)]
    nan_metrics = ["nan"] * len(_METRIC_COLS)
    min_len = max(col_idx.values()) + 1
    scores: List[Tuple[float, float, float]] = []

    for r in data_rows:
        if not any(cell.strip() for cell in r):
            continue  # blank line
        if len(r) < min_len:
            print(f"[WARN] too few columns, skipped: {r}")
            continue
        image_path = r[col_idx["image_path"]].strip()
        mask_path = r[col_idx["mask_path"]].strip()
        try:
            coords = [float(r[col_idx[k]]) for k in ("x1", "y1", "x2", "y2")]
        except ValueError:
            print(f"[WARN] non-numeric box, skipped: {r}")
            continue
        kept = [r[i] if i < len(r) else "" for i in keep_idx]

        if not os.path.isfile(image_path):
            print(f"[WARN] image not found: {image_path}")
            out_rows.append(kept + nan_metrics)
            continue
        if not os.path.isfile(mask_path):
            print(f"[WARN] mask not found: {mask_path}")
            out_rows.append(kept + nan_metrics)
            continue

        img = to_uint8_rgb(io.imread(image_path))
        msk = binarize_mask(io.imread(mask_path))
        H, W = img.shape[:2]
        if msk.shape[:2] != (H, W):
            print(f"[WARN] shape mismatch (img {img.shape[:2]} vs mask {msk.shape[:2]}). Skipped: {image_path}")
            out_rows.append(kept + nan_metrics)
            continue

        box = clip_box_to_image(np.array(coords, dtype=np.float32), width=W, height=H)
        if predictor is None:
            pred_bin = _predict_medsam(sam, img, box, args.device)
        else:
            pred_bin = _predict_sam(predictor, img, box, args.multimask_output)
        dice, iou, pixel_acc = compute_metrics(msk, pred_bin)
        scores.append((dice, iou, pixel_acc))

        if args.save_pred_dir:
            out_mask = Path(args.save_pred_dir) / f"{Path(image_path).stem}.png"
            io.imsave(str(out_mask), (pred_bin * 255).astype(np.uint8), check_contrast=False)
        out_rows.append(kept + [f"{dice:.6f}", f"{iou:.6f}", f"{pixel_acc:.6f}"])

    os.makedirs(os.path.dirname(os.path.abspath(args.csv_out)), exist_ok=True)
    with open(args.csv_out, "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(out_rows)
    print(f"Done. Wrote {len(out_rows)-1} rows to: {args.csv_out}")
    if scores:
        mean_dice, mean_iou, mean_acc = np.mean(np.asarray(scores), axis=0)
        print(f"Mean over {len(scores)} evaluated rows: Dice {mean_dice:.4f} | "
              f"IoU {mean_iou:.4f} | Pixel_Acc {mean_acc:.4f}")
    if args.save_pred_dir:
        print(f"Pred masks saved to: {args.save_pred_dir}")
if __name__ == "__main__":
    main()

# NOTE(review): the text below had been pasted onto the `main()` line (a syntax
# error); it is kept here as a comment. Translated from Chinese:
# "This is an inference script that is about to be run. Its goal is to first use
# the boxes in the CSV file as prompts so the model produces masks, and then
# compute Dice and other metrics in batch between those masks and the provided
# labels. The part between the two `#______` markers was added from another
# file; its purpose is to first build a SAM from the base MedSAM ViT-B weights
# and then, on top of it, load the MedSAM that was fully fine-tuned on the
# current dataset, in order to test the fine-tuning result. The code of the
# other file is as follows:"
# -*- coding: utf-8 -*-
"""
Run batched inference with MedSAM and compute dataset-level Dice and IoU.
#datasets/val/npy/
Example:
python MedSAM_Inference_v2.py \
--images_dir train_val_test/val/imgs \
--masks_dir train_val_test/val/gts \
--output_dir train_val_test/val/output \
--bld_ckpt /home/lin/Medical_Img_Segment/MedSAM_PEFT_v1/work_dir/MedSAM/medsam_vit_b.pth \
--evl_ckpt /path/to/medsam_model_best.pth \
--batch_size 4 --device cuda:0 --save_preds
python /root/autodl-tmp/MedSAM_PEFT/MedSAM_Inference_v2.py \
--images_dir /root/autodl-tmp/MedSAM_PEFT/datasets/ISIC2018/val/npy/Derm_SkinCancer/imgs \
--masks_dir /root/autodl-tmp/MedSAM_PEFT/datasets/ISIC2018/val/npy/Derm_SkinCancer/gts \
--output_dir train_val_test/val/output \
--bld_ckpt /root/autodl-tmp/MedSAM_PEFT/work_dir/MedSAM/medsam_vit_b.pth \
--evl_ckpt /root/autodl-tmp/MedSAM_PEFT/work_dir/MedSAM/work_dir/MedSAM-ViT-B-ISIC2018-20251009-0021/medsam_model_best.pth \
--batch_size 8 --device cuda:0 \
# --save_preds /root/autodl-tmp/MedSAM_PEFT/train_val_test/val/save_preds
grep -n "add_argument.*--save_preds" MedSAM_Inference_v2.py
/root/autodl-tmp/MedSAM_PEFT/MedSAM_Inference_v2.py
python MedSAM_Inference_v2.py --help
"""
import os
import argparse
from typing import List, Tuple
from pathlib import Path
import numpy as np
from skimage import io, transform
from skimage.morphology import binary_erosion
from scipy.ndimage import distance_transform_edt
import torch
import torch.nn.functional as F
from segment_anything import sam_model_registry
def ensure_numpy(array_like):
    """Return ``array_like`` as a NumPy array on the CPU with values untouched.

    NumPy arrays are returned as the very same object; torch tensors are detached
    and converted via ``.cpu().numpy()``. Any other type raises ``TypeError``.
    """
    if isinstance(array_like, torch.Tensor):
        return array_like.detach().cpu().numpy()
    if not isinstance(array_like, np.ndarray):
        raise TypeError(f"Unsupported type: {type(array_like)}")
    return array_like
def binarize_mask(mask: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Threshold ``mask`` into a ``{0, 1}`` uint8 array.

    Values strictly greater than ``threshold`` (compared as float32) become 1.
    Accepts numpy arrays or torch tensors.
    """
    values = ensure_numpy(mask).astype(np.float32)
    return np.where(values > threshold, 1, 0).astype(np.uint8)
def mean_dice_np(y_true, y_pred, smooth: float = 1e-3) -> "float | np.ndarray":
    """Compute the Dice score per sample for binary masks.

    Args:
        y_true, y_pred: ``(H, W)`` or ``(B, H, W)`` masks (numpy or torch); any
            non-zero value counts as foreground.
        smooth: additive smoothing; two empty masks therefore score 1.0.

    Returns:
        A Python float when both inputs are single ``(H, W)`` masks. Otherwise a
        length-``B`` array, including ``B == 1``, so callers can always use
        ``.tolist()`` on batched results.
    """
    y_true = ensure_numpy(y_true)
    y_pred = ensure_numpy(y_pred)
    single = y_true.ndim == 2 and y_pred.ndim == 2
    if y_true.ndim == 2:
        y_true = y_true[None, ...]
    if y_pred.ndim == 2:
        y_pred = y_pred[None, ...]
    gt = (np.abs(y_true) > 0).astype(np.float32)
    pr = (np.abs(y_pred) > 0).astype(np.float32)
    axes = (1, 2)
    intersection = np.sum(gt * pr, axis=axes)
    mask_sum = np.sum(gt, axis=axes) + np.sum(pr, axis=axes)
    dice = (2.0 * intersection + smooth) / (mask_sum + smooth)
    return dice.item() if single else dice
def mean_iou_np(y_true, y_pred, smooth: float = 1e-3) -> "float | np.ndarray":
    """Compute IoU per sample for binary masks.

    Args:
        y_true, y_pred: ``(H, W)`` or ``(B, H, W)`` masks (numpy or torch); any
            non-zero value counts as foreground.
        smooth: additive smoothing; two empty masks therefore score 1.0.

    Returns:
        A Python float when both inputs are single ``(H, W)`` masks. Otherwise a
        length-``B`` array, including ``B == 1``.
    """
    y_true = ensure_numpy(y_true)
    y_pred = ensure_numpy(y_pred)
    single = y_true.ndim == 2 and y_pred.ndim == 2
    if y_true.ndim == 2:
        y_true = y_true[None, ...]
    if y_pred.ndim == 2:
        y_pred = y_pred[None, ...]
    gt = (np.abs(y_true) > 0).astype(np.float32)
    pr = (np.abs(y_pred) > 0).astype(np.float32)
    axes = (1, 2)
    intersection = np.sum(gt * pr, axis=axes)
    union = np.sum((gt + pr) > 0, axis=axes).astype(np.float32)
    iou = (intersection + smooth) / (union + smooth)
    return iou.item() if single else iou
def list_images(directory: str, exts: Tuple[str, ...] = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")) -> List[str]:
    """Return full paths of entries in ``directory`` whose extension is in ``exts``.

    The extension comparison is case-insensitive and results follow sorted
    filename order.
    """
    def wanted(name: str) -> bool:
        return os.path.splitext(name)[1].lower() in exts

    return [os.path.join(directory, name) for name in sorted(os.listdir(directory)) if wanted(name)]
def extract_surface(mask: np.ndarray) -> np.ndarray:
    """Return the boolean boundary of the foreground (``mask > 0``).

    The boundary is the foreground XOR its binary erosion, so an object that
    erosion removes entirely (e.g. a single pixel) is its own boundary. An
    all-background mask yields an all-False array.
    """
    foreground = (ensure_numpy(mask) > 0).astype(bool)
    if not foreground.any():
        return np.zeros_like(foreground, dtype=bool)
    return np.logical_xor(foreground, binary_erosion(foreground))
def nsd_np(y_true, y_pred, tolerance: float) -> "float | np.ndarray":
    """Normalized Surface Dice (NSD) at ``tolerance`` pixels, per sample.

    NSD is the fraction of boundary pixels of either mask that lie within
    ``tolerance`` of the other mask's boundary. Two empty masks score 1.0. If
    exactly one boundary is empty, no point can be within tolerance, so the score
    is 0.0.

    Returns:
        A Python float when both inputs are single ``(H, W)`` masks. Otherwise an
        array with one value per sample, including ``B == 1`` and ``B == 0``.
    """
    y_true = ensure_numpy(y_true)
    y_pred = ensure_numpy(y_pred)
    single = y_true.ndim == 2 and y_pred.ndim == 2
    if y_true.ndim == 2:
        y_true = y_true[None, ...]
    if y_pred.ndim == 2:
        y_pred = y_pred[None, ...]
    nsd_vals: List[float] = []
    for gt, pr in zip(y_true, y_pred):
        surf_gt = extract_surface(gt)
        surf_pr = extract_surface(pr)
        n_gt = int(surf_gt.sum())
        n_pr = int(surf_pr.sum())
        if n_gt == 0 and n_pr == 0:
            nsd_vals.append(1.0)
            continue
        if n_gt == 0 or n_pr == 0:
            # distance_transform_edt(~empty_surface) would have no zero to measure
            # to and returns meaningless distances; the correct score is 0.
            nsd_vals.append(0.0)
            continue
        # distance_transform_edt measures distance to the nearest zero, hence ~surface.
        dt_to_pr = distance_transform_edt(~surf_pr)
        dt_to_gt = distance_transform_edt(~surf_gt)
        tp_gt = int(np.count_nonzero(dt_to_pr[surf_gt] <= tolerance))
        tp_pr = int(np.count_nonzero(dt_to_gt[surf_pr] <= tolerance))
        nsd_vals.append((tp_gt + tp_pr) / float(n_gt + n_pr))
    if single:
        return float(nsd_vals[0])
    return np.asarray(nsd_vals, dtype=np.float64)
def load_image_as_3c(path: str) -> np.ndarray:
    """Read an image file and return it as an ``(H, W, 3)`` array (dtype unchanged).

    Grayscale ``(H, W)`` / ``(H, W, 1)`` and gray+alpha ``(H, W, 2)`` images are
    replicated to three channels. RGBA or any other image with more than three
    channels is cut to its first three.
    """
    img = io.imread(path)
    if img.ndim == 3 and img.shape[-1] in (1, 2):
        # Gray stored with a channel axis (optionally + alpha): keep the gray plane.
        img = img[..., 0]
    if img.ndim == 2:
        img = np.repeat(img[..., None], 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] > 3:
        img = img[..., :3]
    return img
def load_mask_binary(path: str) -> np.ndarray:
    """Read a mask image and return it as a ``{0, 1}`` uint8 array.

    For a multi-channel (e.g. RGB) mask, a pixel is foreground when any channel
    is non-zero.
    """
    foreground = io.imread(path) > 0
    if foreground.ndim == 3:
        foreground = foreground.any(axis=-1)
    return foreground.astype(np.uint8)
def compute_bbox_from_mask(mask: np.ndarray) -> np.ndarray:
    """Return the tight, inclusive ``[x1, y1, x2, y2]`` int32 box around ``mask > 0``.

    An empty mask yields the whole-image box ``[0, 0, W - 1, H - 1]``.
    """
    rows, cols = np.nonzero(mask > 0)
    if cols.size == 0:
        height, width = mask.shape[0], mask.shape[1]
        return np.array([0, 0, width - 1, height - 1], dtype=np.int32)
    return np.array([cols.min(), rows.min(), cols.max(), rows.max()], dtype=np.int32)
@torch.no_grad()
def run_batch_inference(model, imgs_3c: List[np.ndarray], boxes_np: List[np.ndarray], device: str):
    """Run box-prompted MedSAM inference on a batch of images.

    Args:
        model: object exposing ``image_encoder``, ``prompt_encoder`` and
            ``mask_decoder`` (a SAM / MedSAM model).
        imgs_3c: ``(H, W, 3)`` images (numpy or torch, uint8 or float). Sizes may
            differ within the batch.
        boxes_np: one ``[x1, y1, x2, y2]`` box per image, in that image's pixel
            coordinates (numpy or torch).
        device: torch device string for the forward pass.

    Returns:
        One ``(H, W)`` uint8 {0, 1} mask per image, at that image's original size.
        An empty batch returns ``[]``.
    """
    imgs = [ensure_numpy(im) for im in imgs_3c]
    boxes = [ensure_numpy(b).astype(np.float32).reshape(4) for b in boxes_np]
    if not imgs:
        return []
    HWs = [(im.shape[0], im.shape[1]) for im in imgs]

    imgs_1024 = []
    boxes_1024 = []
    for (H, W), img, box in zip(HWs, imgs, boxes):
        # Resize to the 1024x1024 encoder input and min-max normalise to [0, 1].
        # Stay in float: casting to uint8 here would zero out images already in [0, 1].
        resized = transform.resize(
            img, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
        ).astype(np.float32)
        lo = float(resized.min())
        span = max(float(resized.max()) - lo, 1e-8)
        imgs_1024.append((resized - lo) / span)
        boxes_1024.append(box / np.array([W, H, W, H], dtype=np.float32) * 1024.0)

    img_tensor = torch.from_numpy(np.stack(imgs_1024, axis=0)).float().permute(0, 3, 1, 2).to(device)
    img_embed = model.image_encoder(img_tensor)  # expected (B, 256, 64, 64) for ViT-B

    box_torch = torch.as_tensor(np.stack(boxes_1024, axis=0), dtype=torch.float32, device=device)
    box_torch = box_torch[:, None, :]  # (B, 1, 4)
    sparse_embeddings, dense_embeddings = model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )
    low_res_prob = torch.sigmoid(low_res_logits)  # (B, 1, h, w)

    preds: List[np.ndarray] = []
    for b, (Hb, Wb) in enumerate(HWs):
        prob_b = F.interpolate(low_res_prob[b:b + 1], size=(Hb, Wb), mode="bilinear", align_corners=False)
        preds.append((prob_b[0, 0] > 0.5).to(torch.uint8).cpu().numpy())
    return preds
def main():
    """Evaluate a fine-tuned MedSAM on a folder of images + GT masks.

    Images and masks may be ordinary image files or MedSAM-style ``.npy`` arrays
    (e.g. ``.../npy/<task>/imgs`` and ``.../gts``). The box prompt for each image is
    the tight box around its GT mask. Dice, IoU and NSD are computed per image and
    averaged; results go to ``<output_dir>/metrics.txt``.
    """
    parser = argparse.ArgumentParser(description="Batched MedSAM inference and metrics")
    parser.add_argument("--images_dir", type=str, default="train_val_test/val/imgs", help="Directory with input images")
    parser.add_argument("--masks_dir", type=str, default="train_val_test/val/gts", help="Directory with GT masks")
    parser.add_argument("--output_dir", type=str, default="train_val_test/val/outputs", help="Where to save predictions and metrics")
    parser.add_argument("--bld_ckpt", type=str, default="work_dir/MedSAM/medsam_vit_b.pth", help="Model checkpoint path for building MedSAM")
    parser.add_argument("--evl_ckpt", type=str, default="work_dir/MedSAM/work_dir/MedSAM-ViT-B-ISIC2018-20251009-0021/medsam_model_best.pth", help="Model checkpoint path for evaluation")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device, e.g., cuda:0 or cpu")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for image encoder")
    parser.add_argument("--nsd_tolerance", type=float, default=3.0, help="NSD tolerance in pixels")
    parser.add_argument("--save_preds", action="store_true", help="Save predicted masks")
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    device = args.device

    # (1) Build the SAM backbone from the base MedSAM weights.
    sam = sam_model_registry["vit_b"](checkpoint=args.bld_ckpt)
    # (2) Wrap it the same way as training (project-local class).
    from train_one_gpu_ISIC2018 import MedSAM
    model = MedSAM(
        image_encoder=sam.image_encoder,
        mask_decoder=sam.mask_decoder,
        prompt_encoder=sam.prompt_encoder,
    )
    # (3) Load the fine-tuned weights given on the command line (--evl_ckpt).
    ckpt = torch.load(args.evl_ckpt, map_location="cpu")
    if isinstance(ckpt, dict) and "model" in ckpt:
        ckpt = ckpt["model"]  # {"model": ..., "optimizer": ...} training checkpoint
    if ckpt and all(k.startswith("module.") for k in ckpt):
        ckpt = {k[len("module."):]: v for k, v in ckpt.items()}  # saved from DataParallel/DDP
    load_result = model.load_state_dict(ckpt, strict=False)
    n_loaded = len(ckpt) - len(load_result.unexpected_keys)
    if n_loaded == 0:
        raise RuntimeError(f"No keys in {args.evl_ckpt} match the model; nothing was loaded.")
    if load_result.missing_keys:
        print(f"[WARN] {len(load_result.missing_keys)} model keys not found in checkpoint, "
              f"e.g. {load_result.missing_keys[:5]}")
    if load_result.unexpected_keys:
        print(f"[WARN] {len(load_result.unexpected_keys)} checkpoint keys ignored, "
              f"e.g. {load_result.unexpected_keys[:5]}")
    print(f"Loaded {n_loaded} tensors from {args.evl_ckpt}")
    model = model.to(device)
    model.eval()

    exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".npy")
    img_paths = list_images(args.images_dir, exts=exts)
    if len(img_paths) == 0:
        raise FileNotFoundError(f"No images found in {args.images_dir}")

    # Stem -> mask path index for the fallback lookup, built once (first in sorted order wins).
    mask_by_stem = {}
    for p in sorted(Path(args.masks_dir).iterdir()):
        if p.is_file() and p.suffix.lower() in exts:
            mask_by_stem.setdefault(p.stem.replace("_segmentation", ""), str(p))

    def find_mask_for_image(img_path: str) -> str:
        """Find the mask for an image: same stem, or stem + '_segmentation'
        (ISIC_0000000.png -> ISIC_0000000_segmentation.png), any supported extension."""
        base = os.path.splitext(os.path.basename(img_path))[0]
        for suffix in ("", "_segmentation"):
            for ext in exts:
                cand = os.path.join(args.masks_dir, f"{base}{suffix}{ext}")
                if os.path.exists(cand):
                    return cand
        if base in mask_by_stem:
            return mask_by_stem[base]
        raise FileNotFoundError(f"No corresponding mask found for {img_path}")

    def load_image(path: str) -> np.ndarray:
        """Load an (H, W, 3) image; non-uint8 .npy arrays are min-max rescaled to uint8."""
        if not path.lower().endswith(".npy"):
            return load_image_as_3c(path)
        img = np.load(path)
        if img.ndim == 2:
            img = np.repeat(img[..., None], 3, axis=-1)
        # Assumes HWC layout, as MedSAM's npy images are usually stored — confirm.
        if img.dtype != np.uint8:
            # MedSAM npy images are typically float in [0, 1]; hand run_batch_inference
            # a 0..255 uint8 image so its resize + normalisation behaves as for PNGs.
            lo, hi = float(img.min()), float(img.max())
            if hi > lo:
                img = np.round((img - lo) / (hi - lo) * 255.0).astype(np.uint8)
            else:
                img = np.zeros(img.shape, dtype=np.uint8)
        return img

    def load_mask(path: str) -> np.ndarray:
        """Load a {0, 1} uint8 mask from an image file or a .npy label array."""
        if not path.lower().endswith(".npy"):
            return load_mask_binary(path)
        m = np.load(path)
        if m.ndim == 3:
            m = np.any(m > 0, axis=-1)
        return (m > 0).astype(np.uint8)

    all_dice: List[float] = []
    all_iou: List[float] = []
    all_nsd: List[float] = []
    batch_imgs: List[np.ndarray] = []
    batch_boxes: List[np.ndarray] = []
    batch_gts: List[np.ndarray] = []
    batch_names: List[str] = []

    def flush_batch() -> None:
        """Run inference on the pending batch, record per-image metrics, optionally save preds."""
        if not batch_imgs:
            return
        preds = run_batch_inference(model, batch_imgs, batch_boxes, device)
        for name, gt, pred in zip(batch_names, batch_gts, preds):
            pred_bin = binarize_mask(pred)
            # Per-image 2-D calls: images in a batch may have different sizes.
            all_dice.append(float(mean_dice_np(gt, pred_bin)))
            all_iou.append(float(mean_iou_np(gt, pred_bin)))
            all_nsd.append(float(nsd_np(gt, pred_bin, tolerance=args.nsd_tolerance)))
            if args.save_preds:
                out_path = os.path.join(args.output_dir, f"pred_{name}.png")
                io.imsave(out_path, (pred_bin * 255).astype(np.uint8), check_contrast=False)
        batch_imgs.clear()
        batch_boxes.clear()
        batch_gts.clear()
        batch_names.clear()

    for img_path in img_paths:
        img = load_image(img_path)
        gt = load_mask(find_mask_for_image(img_path))
        if gt.shape != img.shape[:2]:
            print(f"[WARN] shape mismatch (img {img.shape[:2]} vs mask {gt.shape}). Skipped: {img_path}")
            continue
        batch_imgs.append(img)
        batch_boxes.append(compute_bbox_from_mask(gt))
        batch_gts.append(gt)
        batch_names.append(os.path.splitext(os.path.basename(img_path))[0])
        if len(batch_imgs) >= args.batch_size:
            flush_batch()
    flush_batch()

    mean_dice = float(np.mean(all_dice)) if len(all_dice) else 0.0
    mean_iou = float(np.mean(all_iou)) if len(all_iou) else 0.0
    mean_nsd = float(np.mean(all_nsd)) if len(all_nsd) else 0.0
    with open(os.path.join(args.output_dir, "metrics.txt"), "w") as f:
        f.write(f"Num samples: {len(all_dice)}\n")
        f.write(f"Mean Dice: {mean_dice:.6f}\n")
        f.write(f"Mean IoU: {mean_iou:.6f}\n")
        f.write(f"Mean NSD@{args.nsd_tolerance}px: {mean_nsd:.6f}\n")
    print(f"Num samples: {len(all_dice)}")
    print(f"Mean Dice: {mean_dice:.6f}")
    print(f"Mean IoU: {mean_iou:.6f}")
    print(f"Mean NSD@{args.nsd_tolerance}px: {mean_nsd:.6f}")
# Script entry point: run main() only when this file is executed directly.
if __name__ == "__main__":
    main()