# backend/run_clam.py
import os
import shutil
import subprocess
import textwrap
from typing import Dict, Tuple, Callable, Optional
import numpy as np
import pandas as pd
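# Restrict the pipeline to GPU index 4 on the host machine (device numbering as seen by CUDA).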
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
# Locations of the validation workspace (where runs and outputs are saved) and of the CLAM repository.
VAL_ROOT = "/commondocument/group10/wuqiuyang/Validation_on_CLAM2"
CLAM_ROOT = "/commondocument/group10/tom/newclam/CLAM"
# Trained CLAM checkpoint used for inference (stored under VAL_ROOT).
CKPT_PATH = os.path.join(VAL_ROOT, "trained_model", "s_full_data_checkpoint.pt")
# Segmentation/patching preset; create_heatmaps.py appears to need this .csv for the results to be output correctly.
PRESET_PATH = os.path.join(VAL_ROOT, "presets", "bwh_biopsy.csv")
# Hugging Face mirror settings so extract_features_fp.py can download pretrained encoder weights.
HF_ENV = {
"HF_ENDPOINT": "https://hf-mirror.com",
"HF_HUB_ENDPOINT": "https://hf-mirror.com",
"HF_HUB_ETAG_TIMEOUT": "120",
"HF_HUB_DOWNLOAD_TIMEOUT": "600",
}
def _ensure_basic_files():
"""Check CKPT_path and PRESET_path are present"""
if not os.path.exists(CKPT_PATH):
raise FileNotFoundError(f"找不到 CLAM checkpoint:{CKPT_PATH}")
if not os.path.exists(PRESET_PATH):
raise FileNotFoundError(f"找不到 preset 文件:{PRESET_PATH}")
def _write_yaml_config(run_root: str, slide_id: str, slide_ext: str) -> str:
"""在 run_root 里写一份只针对这一次推理的 heatmap 配置."""
test_data_dir = os.path.join(run_root, "test_data")
yaml_path = os.path.join(run_root, "my_heatmap.yaml")
yaml_text = f"""
exp_arguments:
n_classes: 2
save_exp_code: HEATMAP_OUTPUT
raw_save_dir: heatmaps/heatmap_raw_results
production_save_dir: heatmaps/heatmap_production_results
batch_size: 256
data_arguments:
data_dir: {test_data_dir}
data_dir_key: source
process_list: my_slides_for_heatmap.csv
preset: {PRESET_PATH}
slide_ext: {slide_ext}
label_dict:
normal_tissue: 0
tumor_tissue: 1
patching_arguments:
patch_size: 256
overlap: 0.5
patch_level: 0
custom_downsample: 1
encoder_arguments:
model_name: resnet50_trunc
target_img_size: 224
model_arguments:
ckpt_path: {CKPT_PATH}
model_type: clam_sb
initiate_fn: initiate_model
model_size: small
drop_out: 0.
embed_dim: 1024
heatmap_arguments:
vis_level: 3
alpha: 0.4
blank_canvas: false
save_orig: false
save_ext: jpg
use_ref_scores: false
blur: false
use_center_shift: true
use_roi: false
calc_heatmap: true
binarize: false
binary_thresh: -1
custom_downsample: 10
cmap: jet
sample_arguments:
samples:
- name: "topk_high_attention"
sample: true
seed: 1
k: 15
mode: topk
"""
yaml_text = textwrap.dedent(yaml_text).strip() + "\n"
with open(yaml_path, "w", encoding="utf-8") as f:
f.write(yaml_text)
return yaml_path
# The heatmap output is downsampled (custom_downsample in heatmap_arguments above) so the figure is small enough to display on the website.
def _find_heatmap_image(run_root: str, slide_id: str) -> str:
"""
在生产结果目录里找当前 slide 的 jpg 热图:
1)只考虑文件名包含 slide_id 的 jpg
2)优先在“非 orig” 文件中选体积最小的一张(overlay 图)
3)如果没有非 orig,则在“orig” 文件中选体积最小的一张
"""
prod_root = os.path.join(run_root, "heatmaps", "heatmap_production_results")
if not os.path.isdir(prod_root):
raise FileNotFoundError(f"找不到热图输出目录: {prod_root}")
smallest_non_orig = None # (size_bytes, path)
smallest_orig = None # (size_bytes, path)
for root, _, files in os.walk(prod_root):
for name in files:
lower = name.lower()
if not lower.endswith(".jpg"):
continue
if slide_id not in name:
continue
full_path = os.path.join(root, name)
try:
size_bytes = os.path.getsize(full_path)
except OSError:
continue
if "orig" in lower:
if smallest_orig is None or size_bytes < smallest_orig[0]:
smallest_orig = (size_bytes, full_path)
else:
if smallest_non_orig is None or size_bytes < smallest_non_orig[0]:
smallest_non_orig = (size_bytes, full_path)
if smallest_non_orig is not None:
return smallest_non_orig[1]
if smallest_orig is not None:
return smallest_orig[1]
raise FileNotFoundError(f"在 {prod_root} 中没找到 {slide_id} 对应的 jpg 热图文件")
def _read_clam_result(run_root: str, slide_id: str) -> Dict:
"""从 heatmaps/results 中读出本次 slide 的结构化预测结果."""
results_dir = os.path.join(run_root, "heatmaps", "results")
if not os.path.isdir(results_dir):
raise FileNotFoundError(f"找不到结果目录: {results_dir}")
result_csv = os.path.join(results_dir, "my_slides_for_heatmap.csv")
if not os.path.exists(result_csv):
for name in os.listdir(results_dir):
if name.lower().endswith(".csv"):
result_csv = os.path.join(results_dir, name)
break
if not os.path.exists(result_csv):
raise FileNotFoundError(f"在 {results_dir} 中没找到任何 csv 结果文件")
df = pd.read_csv(result_csv)
if "slide_id" in df.columns:
sub = df[df["slide_id"] == slide_id]
if len(sub) == 0:
row = df.iloc[0]
else:
row = sub.iloc[0]
else:
row = df.iloc[0]
result = {}
for k, v in row.to_dict().items():
if isinstance(v, np.generic):
v = v.item()
result[k] = v
return result
def run_clam_inference(
tiff_path: str,
progress_cb: Optional[Callable[[str, float], None]] = None,
) -> Tuple[Dict, str]:
"""
给定一张 .tif/.tiff 切片路径:
1)复制到 Validation_on_CLAM2/web_runs/{slide_id}/test_data
2)运行 create_patches_fp.py
3)运行 extract_features_fp.py
4)生成 my_slides_for_heatmap.csv + my_heatmap.yaml
5)运行 create_heatmaps.py(自动回答 Y)
6)返回: (clam_result_dict, heatmap_image_path)
progress_cb: 可选回调,形如 progress_cb(message: str, progress: float)
progress 取值 [0,1],用于前端展示进度条。
"""
def update(msg: str, p: float):
if progress_cb is not None:
progress_cb(msg, p)
_ensure_basic_files()
if not os.path.exists(tiff_path):
raise FileNotFoundError(f"找不到输入切片:{tiff_path}")
slide_name = os.path.basename(tiff_path)
slide_id, ext = os.path.splitext(slide_name)
ext = ext.lower()
if ext not in [".tif", ".tiff"]:
raise ValueError(f"仅支持 .tif/.tiff,当前是 {ext}")
update("初始化环境与目录…", 0.02)
# 每张 slide 一个独立的 run_root,方便日后多例管理
run_root = os.path.join(VAL_ROOT, "web_runs", slide_id)
test_data_dir = os.path.join(run_root, "test_data")
clam_output_dir = os.path.join(run_root, "clam_output")
os.makedirs(test_data_dir, exist_ok=True)
os.makedirs(clam_output_dir, exist_ok=True)
os.makedirs(os.path.join(clam_output_dir, "features"), exist_ok=True)
    # heatmaps subdirectories
os.makedirs(os.path.join(run_root, "heatmaps", "heatmap_raw_results"), exist_ok=True)
os.makedirs(os.path.join(run_root, "heatmaps", "heatmap_production_results"), exist_ok=True)
os.makedirs(os.path.join(run_root, "heatmaps", "results"), exist_ok=True)
os.makedirs(os.path.join(run_root, "heatmaps", "process_lists"), exist_ok=True)
    # Copy the slide into the dedicated test_data directory.
dest_slide_path = os.path.join(test_data_dir, slide_id + ext)
shutil.copy2(tiff_path, dest_slide_path)
    # Shared environment for all subprocess calls.
env = os.environ.copy()
env.update(HF_ENV)
    # ------- Step 1: patch extraction -------
    update("Step 1/3: extracting patches…", 0.10)
cmd_patches = [
"python",
os.path.join(CLAM_ROOT, "create_patches_fp.py"),
"--source", test_data_dir,
"--save_dir", clam_output_dir,
"--patch_size", "256",
"--seg", "--patch", "--stitch",
]
subprocess.run(cmd_patches, check=True, cwd=run_root, env=env)
update("Step 1/3 完成:patch 已生成", 0.33)
# ------- Step 2: 抽特征 -------
update("Step 2/3:正在抽取特征 …", 0.35)
csv_path = os.path.join(clam_output_dir, "process_list_autogen.csv")
cmd_feats = [
"python",
os.path.join(CLAM_ROOT, "extract_features_fp.py"),
"--data_h5_dir", clam_output_dir,
"--data_slide_dir", test_data_dir,
"--csv_path", csv_path,
"--feat_dir", os.path.join(clam_output_dir, "features"),
"--batch_size", "512",
"--slide_ext", ext,
]
subprocess.run(cmd_feats, check=True, cwd=run_root, env=env)
update("Step 2/3 完成:特征已抽取", 0.66)
# ------- Step 3: 生成 process_list CSV -------
update("Step 3/3:准备生成热图 …", 0.70)
process_list_path = os.path.join(run_root, "heatmaps", "process_lists", "my_slides_for_heatmap.csv")
with open(process_list_path, "w", encoding="utf-8") as f:
f.write("slide_id,label\n")
f.write(f"{slide_id},tumor_tissue\n")
    # ------- Step 4: write the heatmap config -------
config_path = _write_yaml_config(run_root=run_root, slide_id=slide_id, slide_ext=ext)
    # ------- Step 5: generate the heatmap (auto-answering Y) -------
cmd_heat = [
"python",
os.path.join(CLAM_ROOT, "create_heatmaps.py"),
"--config", config_path,
]
update("Step 3/3:正在生成热图 …", 0.80)
subprocess.run(
cmd_heat,
check=True,
cwd=run_root,
env=env,
input="Y\n",
text=True,
)
update("Step 3/3 完成:热图已生成", 0.95)
# ------- Step 6: 收集结果 -------
heatmap_img = _find_heatmap_image(run_root, slide_id)
clam_result = _read_clam_result(run_root, slide_id)
clam_result["run_root"] = run_root
update("全部流程完成 ✅", 1.0)
return clam_result, heatmap_img
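

# Minimal local usage sketch (not part of the web backend): shows how run_clam_inference might be
# invoked directly with a simple progress callback that prints to stdout. The slide path below is
# a hypothetical placeholder and must be replaced with a real .tif/.tiff file before running.
if __name__ == "__main__":
    def _print_progress(message: str, progress: float) -> None:
        # Render each stage message with its progress as a percentage.
        print(f"[{progress * 100:5.1f}%] {message}")

    example_slide = "/path/to/some_slide.tiff"  # placeholder path; adjust before running
    result, heatmap_path = run_clam_inference(example_slide, progress_cb=_print_progress)
    print("CLAM result:", result)
    print("Heatmap image:", heatmap_path)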