A MONAI-based tumor detection pipeline for whole-slide pathology images (WSI), covering:
- WSI patch sampling and dataset construction (with WSIPatchDataset)
- Model training (UNet as the example)
- Sliding-window inference (sliding_window_inference) for whole-slide segmentation
- Visualization and whole-slide mask generation

The pipeline relies on MONAI functionality and is compatible with OpenSlide for very large pathology images. The code assumes the data (WSI images and masks) is already prepared, so it is easy to adapt for your own project.
1. Environment setup
pip install "monai[all]" openslide-python torch torchvision matplotlib
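A quick sanity check that the environment is usable (a minimal sketch; note that openslide-python also requires the native OpenSlide library to be installed on the system):
import monai
import openslide
import torch

print("MONAI:", monai.__version__)
print("OpenSlide (python bindings):", openslide.__version__)
print("CUDA available:", torch.cuda.is_available())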
2. Building the patch dataset (WSIPatchDataset)
Assume the data is organized as follows:
dataset/
  images/
    case1.svs
    case2.svs
    ...
  masks/
    case1_mask.tif
    case2_mask.tif
    ...
Patch dataset implementation
from monai.data import WSIPatchDataset
from monai.transforms import (
    LoadImaged, AddChanneld, ScaleIntensityd, ToTensord, Compose
)
import glob

# Collect WSI and mask paths
image_paths = sorted(glob.glob('dataset/images/*.svs'))
mask_paths = sorted(glob.glob('dataset/masks/*.tif'))

# Pair each WSI with its mask
data_dicts = [{"image": img, "mask": msk} for img, msk in zip(image_paths, mask_paths)]

# Patch configuration
patch_size = (256, 256)
level = 0             # pyramid level; 0 is usually the highest resolution
patch_per_wsi = 1000  # number of patches sampled per WSI

# Preprocessing / augmentation
train_transforms = Compose([
    LoadImaged(keys=["image", "mask"], reader='WSIReader'),
    AddChanneld(keys=["image", "mask"]),  # deprecated in recent MONAI; use EnsureChannelFirstd there
    ScaleIntensityd(keys=["image"]),
    ToTensord(keys=["image", "mask"])
])

# Build the patch dataset.
# Note: the class name and constructor arguments vary across MONAI versions
# (recent releases expose PatchWSIDataset / MaskedPatchWSIDataset in monai.data),
# so check them against the documentation of the version you have installed.
train_ds = WSIPatchDataset(
    data=data_dicts,
    patch_size=patch_size,
    transform=train_transforms,
    level=level,
    patch_per_wsi=patch_per_wsi,
    return_patch_iter=False,  # return patch tensors and labels directly
)
from torch.utils.data import DataLoader
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
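Before training it is worth pulling one batch and checking that the tensor shapes match what the UNet below expects (a quick sketch; the key names follow the data_dicts defined above):
batch = next(iter(train_loader))
print(batch["image"].shape, batch["image"].dtype)  # expected: [8, C, 256, 256]
print(batch["mask"].shape, batch["mask"].dtype)    # expected: [8, 1, 256, 256]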
3. Defining the segmentation model
from monai.networks.nets import UNet
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = UNet(
    spatial_dims=2,   # 2D patches
    in_channels=1,    # adjust to your data (e.g. 3 for RGB patches)
    out_channels=2,   # e.g. tumor vs. background
    channels=(32, 64, 128, 256, 512),
    strides=(2, 2, 2, 2),
    num_res_units=2
).to(device)
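A dummy forward pass is a cheap way to confirm that in_channels, the patch size and the stride configuration fit together (purely illustrative; the random tensor stands in for a real patch):
dummy = torch.randn(1, 1, 256, 256, device=device)  # [batch, in_channels, H, W]
with torch.no_grad():
    print(net(dummy).shape)  # expected: torch.Size([1, 2, 256, 256])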
4. Training the model
from monai.losses import DiceLoss
import torch.optim as optim
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = optim.Adam(net.parameters(), lr=1e-4)
max_epochs = 20
for epoch in range(max_epochs):
    net.train()
    epoch_loss = 0
    for batch_data in train_loader:
        inputs = batch_data["image"].to(device)
        labels = batch_data["mask"].to(device)  # integer class indices, shape [B, 1, H, W]
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch {epoch+1}/{max_epochs}, Loss: {epoch_loss/len(train_loader):.4f}")
5. Sliding-window inference (whole-slide mask reconstruction)
At inference time, a window is slid across the whole WSI and the predicted mask is stitched back together:
from monai.inferers import sliding_window_inference
from monai.data import WSIReader
import numpy as np

# Run inference on case1.svs
wsi_path = 'dataset/images/case1.svs'

# 1. Read the whole WSI at the chosen pyramid level.
#    Loading level 0 of a large slide into memory is often infeasible;
#    pick a higher level here or use the tiled approach sketched in section 7.
reader = WSIReader(backend="openslide", level=level)
wsi_obj = reader.read(wsi_path)
wsi_arr, _ = reader.get_data(wsi_obj)            # channel-first array, shape: (C, H, W)
if wsi_arr.ndim == 3 and wsi_arr.shape[0] > 1:
    wsi_arr = wsi_arr[0]                         # multi-channel: keep one channel (or merge channels)
wsi_arr = wsi_arr.astype(np.float32) / 255.0     # roughly matches ScaleIntensityd used in training
wsi_tensor = torch.from_numpy(wsi_arr).unsqueeze(0).unsqueeze(0).to(device)  # shape: [1, 1, H, W]

# 2. Sliding-window inference
roi_size = (256, 256)
sw_batch_size = 4
net.eval()
with torch.no_grad():
    pred_mask = sliding_window_inference(
        wsi_tensor, roi_size, sw_batch_size, net, overlap=0.25
    )
pred_mask = torch.argmax(pred_mask, dim=1).cpu().numpy()[0]  # [H, W]
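The predicted mask can be written to disk for downstream use; a minimal sketch using Pillow (the output path is arbitrary):
from PIL import Image

# map class indices {0, 1} to {0, 255} so the mask is viewable as a grayscale image
Image.fromarray((pred_mask * 255).astype(np.uint8)).save("case1_pred_mask.png")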
6. Visualizing the segmentation result
import matplotlib.pyplot as plt
plt.figure(figsize=(12,6))
plt.subplot(1,2,1)
plt.title('Original WSI')
plt.imshow(wsi_arr, cmap='gray')
plt.subplot(1,2,2)
plt.title('Predicted Tumor Mask (overlay)')
plt.imshow(wsi_arr, cmap='gray')
plt.imshow(pred_mask, cmap='Reds', alpha=0.7)
plt.show()
7. Notes and further suggestions
- Memory management: for very large WSIs, use a smaller roi_size and write the mask to disk in batches (e.g. save per-tile results and stitch them at the end); see the tiled sketch after this list.
- Multi-channel input: if the slides have multiple color channels, merge them or handle them separately in the transforms.
- Ground-truth alignment: the mask must be spatially aligned with the WSI, and the level parameter must be the same for both.
- Sampling strategy: custom sampling of tumor-only / background-only patches helps balance the classes.
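For the memory-management point above, the sketch below reads the slide tile by tile with OpenSlide instead of loading it all at once, runs the model on each tile and writes the result into a preallocated array (which could also be a np.memmap for huge slides). All names are illustrative; it reuses net, level and device from the sections above, and tile borders are handled crudely (no overlap blending):
import numpy as np
import openslide
import torch

slide = openslide.OpenSlide('dataset/images/case1.svs')
W, H = slide.level_dimensions[level]             # (width, height) at the chosen level
scale = int(slide.level_downsamples[level])      # read_region expects level-0 coordinates
tile = 1024                                      # tile edge length in pixels
full_mask = np.zeros((H, W), dtype=np.uint8)     # could be np.memmap(...) for very large slides

net.eval()
with torch.no_grad():
    for y in range(0, H, tile):
        for x in range(0, W, tile):
            w, h = min(tile, W - x), min(tile, H - y)
            region = slide.read_region((x * scale, y * scale), level, (w, h)).convert("L")
            arr = np.asarray(region, dtype=np.float32) / 255.0      # [h, w], scaled to [0, 1]
            # pad to multiples of 16 so the 4-level UNet accepts the tile
            ph, pw = (16 - h % 16) % 16, (16 - w % 16) % 16
            arr = np.pad(arr, ((0, ph), (0, pw)))
            inp = torch.from_numpy(arr)[None, None].to(device)      # [1, 1, h+ph, w+pw]
            pred = torch.argmax(net(inp), dim=1)[0].cpu().numpy()   # [h+ph, w+pw]
            full_mask[y:y + h, x:x + w] = pred[:h, :w].astype(np.uint8)

np.save('case1_full_mask.npy', full_mask)        # or stitch / compress however you prefer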
8. Official resources
For real projects, refer to the official MONAI pathology tutorials for downloading public pathology datasets, dataset-specific preprocessing scripts, WSI-specific sampling tricks, multi-level pyramid inference, custom patch sampling, and efficient saving of inference results.