I. Dataset Preprocessing
1. Dataset preparation
The training data can come from two sources: your own dataset or a public dataset.
(1) Your own dataset
Annotating your own segmentation dataset:
Use labelme (which also supports model-assisted annotation) to label the images. The annotated data consists of an images folder (the segmentation images) and a labels folder (the corresponding .json annotation files).
For the detailed annotation workflow, see my other article:
https://blog.youkuaiyun.com/sunshineine/article/details/147231877?spm=1001.2014.3001.5502
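Before converting the labels, it can be worth checking that every image actually has an annotation. A minimal sketch, assuming the images/ and labels/ folder names described above:

from pathlib import Path

images_dir = Path('images')  # folder with the segmentation images
labels_dir = Path('labels')  # folder with the labelme .json files

# List any image that does not yet have a matching .json annotation.
missing = [p.name for p in sorted(images_dir.iterdir())
           if p.is_file() and not (labels_dir / f'{p.stem}.json').exists()]
print('images without annotation:', missing)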
Data conversion (turning the .json labels into pseudo-color label images):
Use the labelme command-line tool labelme_export_json to convert the .json labels into pseudo-color label images, then collect the label.png file from each exported image folder into a single directory.
For the detailed conversion workflow, see my other article:
https://blog.youkuaiyun.com/sunshineine/article/details/147422668
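Collecting the per-image label.png files into one folder can be done with a few lines of Python. A minimal sketch; labels_json_exported and labels_png are hypothetical folder names, adjust them to your own layout:

import shutil
from pathlib import Path

# Hypothetical paths; adjust to your own layout.
export_root = Path('labels_json_exported')  # one sub-folder per image, produced by labelme_export_json
merged_dir = Path('labels_png')             # single folder collecting every label.png
merged_dir.mkdir(exist_ok=True)

for sub in sorted(p for p in export_root.iterdir() if p.is_dir()):
    label_png = sub / 'label.png'
    if label_png.is_file():
        # Name each copy after its source folder so the files do not collide.
        shutil.copy(label_png, merged_dir / f'{sub.name}.png')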
(2) Public datasets
2. Packaging the "image + GT + embedding" data of the segmentation dataset
To avoid paying the cost of running the image encoder (e.g. ViT-H) on every image over and over, the "image + GT + embedding" data of the segmentation dataset is precomputed and packaged in advance, which makes the later training, testing, and inference stages much faster.
For the detailed packaging workflow, see my other article:
https://blog.youkuaiyun.com/sunshineine/article/details/147465912
3. Writing the Dataset class
To handle the custom .npz data format and the task-specific requirements flexibly, you need to build your own Dataset that subclasses PyTorch's Dataset class.
For the detailed implementation, see my other article:
https://blog.youkuaiyun.com/sunshineine/article/details/147472085?spm=1001.2014.3001.5502
II. Setting Up the SAM Project
1. Uploading the code to the server
Download the SAM source code from GitHub to your local machine, then transfer it to the server with Xftp 8.
https://github.com/facebookresearch/segment-anything

2. 环境配置
后续操作在服务器中实现,创建虚拟环境sam,训练SAM需要配置好cuda、cudnn、torch、torchvision、torchaudio,具体版本如下:
python=3.8、cuda=11.8.0、cudnn=9.3.0.75、torch=2.0.0+cu118、torchaudio=2.0.0+cu118、torchvision=0.15.1+cu118
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_gnu conda-forge
brotli-python 1.1.0 py38h17151c0_1 conda-forge
bzip2 1.0.8 h4bc722e_7 conda-forge
ca-certificates 2025.1.31 hbcca054_0 conda-forge
certifi 2024.8.30 pyhd8ed1ab_0 conda-forge
charset-normalizer 3.4.0 pyhd8ed1ab_0 conda-forge
cmake 4.0.0 pypi_0 pypi
coloredlogs 15.0.1 pypi_0 pypi
contourpy 1.1.1 pypi_0 pypi
cuda-version 11.8 h70ddcb2_3 conda-forge
cudatoolkit 11.8.0 h4ba93d1_13 conda-forge
cudnn 9.3.0.75 hc149ed2_0 <unknown>
cycler 0.12.1 pypi_0 pypi
filelock 3.16.1 pyhd8ed1ab_0 conda-forge
flatbuffers 25.2.10 pypi_0 pypi
fonttools 4.56.0 pypi_0 pypi
freetype 2.13.3 h48d6fc4_0 conda-forge
humanfriendly 10.0 pypi_0 pypi
idna 3.10 pyhd8ed1ab_0 conda-forge
importlib-resources 6.4.5 pypi_0 pypi
jinja2 3.1.4 pyhd8ed1ab_0 conda-forge
kernel-headers_linux-64 3.10.0 he073ed8_18 conda-forge
kiwisolver 1.4.7 pypi_0 pypi
lcms2 2.15 h7f713cb_2 conda-forge
ld_impl_linux-64 2.43 h712a8e2_4 conda-forge
lerc 4.0.0 h27087fc_0 conda-forge
libabseil 20230802.1 cxx17_h59595ed_0 conda-forge
libblas 3.9.0 31_h59b9bed_openblas conda-forge
libcblas 3.9.0 31_he106b2a_openblas conda-forge
libdeflate 1.19 hd590300_0 conda-forge
libffi 3.4.6 h2dba641_0 conda-forge
libgcc 14.2.0 h767d61c_2 conda-forge
libgcc-ng 14.2.0 h69a702a_2 conda-forge
libgfortran 14.2.0 h69a702a_2 conda-forge
libgfortran5 14.2.0 hf1ad2bd_2 conda-forge
libgomp 14.2.0 h767d61c_2 conda-forge
libjpeg-turbo 2.1.5.1 hd590300_1 conda-forge
liblapack 3.9.0 31_h7ac8fdf_openblas conda-forge
liblzma 5.6.4 hb9d3cd8_0 conda-forge
liblzma-devel 5.6.4 hb9d3cd8_0 conda-forge
libnsl 2.0.1 hd590300_0 conda-forge
libopenblas 0.3.29 pthreads_h94d23a6_0 conda-forge
libpng 1.6.47 h943b412_0 conda-forge
libprotobuf 4.24.4 hf27288f_0 conda-forge
libsqlite 3.49.1 hee588c1_2 conda-forge
libstdcxx 14.2.0 h8f9b012_2 conda-forge
libstdcxx-ng 14.2.0 h4852527_2 conda-forge
libtiff 4.6.0 h29866fb_1 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libuv 1.50.0 hb9d3cd8_0 conda-forge
libwebp-base 1.5.0 h851e524_0 conda-forge
libxcb 1.15 h0b41bf4_0 conda-forge
libxcrypt 4.4.36 hd590300_1 conda-forge
libzlib 1.3.1 hb9d3cd8_2 conda-forge
lit 18.1.8 pypi_0 pypi
markupsafe 2.1.5 py38h01eb140_0 conda-forge
matplotlib 3.7.5 pypi_0 pypi
monai 1.3.2 pypi_0 pypi
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
ncurses 6.5 h2d0b736_3 conda-forge
networkx 3.1 pyhd8ed1ab_0 conda-forge
nomkl 1.0 h5ca1d4c_0 conda-forge
numpy 1.24.4 py38h59b608b_0 conda-forge
onnx 1.17.0 pypi_0 pypi
onnxruntime 1.19.2 pypi_0 pypi
opencv-python 4.11.0.86 pypi_0 pypi
openjpeg 2.5.2 h488ebb8_0 conda-forge
openssl 3.4.1 h7b32b05_0 conda-forge
packaging 24.2 pypi_0 pypi
pillow 10.0.1 py38h71741d6_1 conda-forge
pip 24.3.1 pyh8b19718_0 conda-forge
protobuf 5.29.4 pypi_0 pypi
pthread-stubs 0.4 hb9d3cd8_1002 conda-forge
pycocotools 2.0.7 pypi_0 pypi
pyparsing 3.1.4 pypi_0 pypi
pysocks 1.7.1 pyha2e5f31_6 conda-forge
python 3.8.20 h4a871b0_2_cpython conda-forge
python-dateutil 2.9.0.post0 pypi_0 pypi
python_abi 3.8 5_cp38 conda-forge
readline 8.2 h8c095d6_2 conda-forge
requests 2.32.3 pyhd8ed1ab_0 conda-forge
segment-anything 1.0 dev_0 <develop>
setuptools 75.3.0 pyhd8ed1ab_0 conda-forge
six 1.17.0 pypi_0 pypi
sleef 3.8 h1b44611_0 conda-forge
sympy 1.13.3 pyh04b8f61_4 conda-forge
sysroot_linux-64 2.17 h0157908_18 conda-forge
tk 8.6.13 noxft_h4845f30_101 conda-forge
torch 2.0.0+cu118 pypi_0 pypi
torchaudio 2.0.0+cu118 pypi_0 pypi
torchvision 0.15.1+cu118 pypi_0 pypi
tqdm 4.67.1 pypi_0 pypi
triton 2.0.0 pypi_0 pypi
typing_extensions 4.12.2 pyha770c72_0 conda-forge
tzdata 2025b h78e105d_0 conda-forge
urllib3 2.2.2 pyhd8ed1ab_0 conda-forge
wheel 0.45.1 pyhd8ed1ab_0 conda-forge
xorg-libxau 1.0.12 hb9d3cd8_0 conda-forge
xorg-libxdmcp 1.1.5 hb9d3cd8_0 conda-forge
xz 5.6.4 hbcc6ac9_0 conda-forge
xz-gpl-tools 5.6.4 hbcc6ac9_0 conda-forge
xz-tools 5.6.4 hb9d3cd8_0 conda-forge
zipp 3.20.2 pypi_0 pypi
zstd 1.5.7 hb8e6e7a_2 conda-forge
If the rented server does not already come with a matching virtual environment, you will need to set one up yourself.
For the detailed setup steps, see my other article:
https://mpbeta.youkuaiyun.com/mp_blog/creation/editor/146763545
3. Installing the SAM package
Activate the sam virtual environment on the server:
conda activate sam
Change into the SAM project directory uploaded via Xftp 8:
cd sam-main
Run the following command to install the segment-anything package in editable mode:
pip install -e .

Install the optional dependencies:
pip install opencv-python pycocotools matplotlib onnxruntime onnx
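To confirm the environment is usable before training, a quick sanity check can be run inside the activated sam environment. A minimal sketch:

import torch
from segment_anything import sam_model_registry

print(torch.__version__)                # expected: 2.0.0+cu118
print(torch.version.cuda)               # expected: 11.8
print(torch.cuda.is_available())        # should print True if the GPU/CUDA setup is correct
print(list(sam_model_registry.keys()))  # ['default', 'vit_h', 'vit_l', 'vit_b']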


4. Project layout
assets: resource files
data: your own dataset, including the training and validation splits
images: images used for prediction
models: trained models (weight files)
notebook: the example scripts provided by the official repository
scripts: scripts for automatic mask generation (amg.py) and ONNX model export (export_onnx_model.py)
segment_anything: the core of the codebase, including:
modeling: the code for the three main components, the image encoder, the prompt encoder, and the mask decoder (see the sketch below)
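A small sketch of how these three components are exposed once the package is installed; it assumes the ViT-H checkpoint has been downloaded into the models folder:

from segment_anything import sam_model_registry

# Assumes the ViT-H checkpoint has already been placed in the models folder.
sam = sam_model_registry['vit_h'](checkpoint='models/sam_vit_h_4b8939.pth')

# The three components implemented under segment_anything/modeling:
print(type(sam.image_encoder).__name__)   # ImageEncoderViT
print(type(sam.prompt_encoder).__name__)  # PromptEncoder
print(type(sam.mask_decoder).__name__)    # MaskDecoder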
III. Model Training
The complete fine-tuning script is shown below. It loads the precomputed embeddings from the .npz file, feeds the box and point prompts through the prompt encoder, and updates only the mask decoder; the latest and best weights are saved each epoch, and the loss/accuracy curves plus a few sample predictions are written to the run directory.
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import monai
from datetime import datetime
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from data.npzdataset import MyselfDataset
import random

# ================== Configuration ==================
npz_tr_path = 'D:/1/SAM/sam-main/data/data.npz'
model_type = 'vit_h'
checkpoint = 'D:/1/SAM/segment-anything-main/models/sam_vit_h_4b8939.pth'
model_save_root = './ultimate'
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = os.path.join(model_save_root, f"train_{timestamp}")
os.makedirs(save_dir, exist_ok=True)
os.makedirs(os.path.join(save_dir, 'vis_results'), exist_ok=True)
device = 'cuda:0'
num_epochs = 100
batch_size = 4
num_classes = 3  # 0: background, 1: rail, 2: obstacle

# ================== Model and optimizer ==================
sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
sam_model.train()
# Only the mask decoder is fine-tuned; the image encoder and prompt encoder are not updated.
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=1e-5)
seg_loss_dice = monai.losses.DiceCELoss(
    to_onehot_y=True, softmax=True, include_background=True,
    sigmoid=False, squared_pred=True, reduction='mean'
)

# ================== Data loading ==================
train_dataset = MyselfDataset(npz_tr_path)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
losses, accuracies = [], []
best_loss = float('inf')

# ================== Training loop ==================
for epoch in range(num_epochs):
    epoch_loss, correct_pixels, total_pixels = 0, 0, 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        img_embed, gt, box, pts, lbls, image, class_label = batch
        # Precomputed ViT-H image embeddings reshaped to (B, 256, 64, 64).
        img_embed = img_embed.view(img_embed.shape[0], 256, 64, 64).to(device)
        gt = gt.to(device)

        def forward_class(gt, box, pts, lbls):
            sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
            box_np = box.numpy()
            box_trans = sam_trans.apply_boxes(box_np, (gt.shape[-2], gt.shape[-1]))
            box_torch = torch.as_tensor(box_trans, dtype=torch.float, device=device)
            pts_torch = torch.as_tensor(pts, dtype=torch.float, device=device)
            lbls_torch = torch.as_tensor(lbls, dtype=torch.int, device=device)
            pt = (pts_torch, lbls_torch)
            sparse, dense = sam_model.prompt_encoder(points=pt, boxes=box_torch, masks=None)
            mask_pred, _ = sam_model.mask_decoder(
                image_embeddings=img_embed,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse,
                dense_prompt_embeddings=dense,
                multimask_output=False,
            )
            # Expand the single-channel prediction to num_classes channels for DiceCELoss.
            mask_pred = mask_pred.repeat(1, num_classes, 1, 1)
            return mask_pred, gt

        mask_pred, gt = forward_class(gt, box, pts, lbls)
        loss = seg_loss_dice(mask_pred, gt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        with torch.no_grad():
            pred = torch.argmax(mask_pred, dim=1)
            correct_pixels += (pred == gt.squeeze(1)).sum().item()
            total_pixels += torch.numel(gt)

    epoch_loss /= (step + 1)
    accuracy = correct_pixels / total_pixels
    losses.append(epoch_loss)
    accuracies.append(accuracy)
    print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {epoch_loss:.6f}, Accuracy: {accuracy:.4f}")

    torch.save(sam_model.state_dict(), os.path.join(save_dir, "latest_model.pth"))
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(sam_model.state_dict(), os.path.join(save_dir, "best_model.pth"))

# ================== Training curves ==================
def smooth_curve(points, factor=0.8):
    smoothed = []
    for point in points:
        if smoothed:
            smoothed.append(smoothed[-1] * factor + point * (1 - factor))
        else:
            smoothed.append(point)
    return smoothed

plt.figure(figsize=(8, 5))
plt.plot(range(1, num_epochs + 1), smooth_curve(losses), label='Smoothed Loss')
plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title('Training Loss'); plt.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(save_dir, 'training_loss.png'))
plt.close()

plt.figure(figsize=(8, 5))
plt.plot(range(1, num_epochs + 1), smooth_curve(accuracies), label='Smoothed Accuracy', color='orange')
plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title('Training Accuracy'); plt.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(save_dir, 'training_accuracy.png'))
plt.close()

# ================== Prediction visualization ==================
print("Loading best model for visualization...")
sam_model.load_state_dict(torch.load(os.path.join(save_dir, "best_model.pth")))
sam_model.eval()

# Colorize the ground-truth mask
def color_mask(gt):
    color_map = np.array([
        [0.0, 0.0, 0.0],  # class 0
        [0.0, 1.0, 0.0],  # class 1
        [1.0, 0.0, 0.0],  # class 2
    ])
    return color_map[gt]

# Predict the mask for a single sample
def predict_overlay(img_embed, gt, box, pts, lbls):
    img_embed = img_embed.view(1, 256, 64, 64).to(device)
    gt = gt.unsqueeze(0).to(device)
    sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
    box = sam_trans.apply_boxes(box.numpy(), (gt.shape[-2], gt.shape[-1]))
    box_torch = torch.as_tensor(box, dtype=torch.float, device=device).unsqueeze(0)
    coords_torch = torch.as_tensor(pts, dtype=torch.float, device=device).unsqueeze(0)
    labels_torch = torch.as_tensor(lbls, dtype=torch.int, device=device).unsqueeze(0)
    pt = (coords_torch, labels_torch)
    with torch.no_grad():
        sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
            points=pt, boxes=box_torch, masks=None
        )
        pred, _ = sam_model.mask_decoder(
            image_embeddings=img_embed,
            image_pe=sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
    pred = pred.repeat(1, num_classes, 1, 1)
    pred = torch.argmax(pred, dim=1).cpu().squeeze().numpy()
    return pred

# Rotate 90 degrees clockwise
def rotate_clockwise_90(img):
    return np.rot90(img, k=3)

# Overlay the (rotated) image with the predicted mask
def overlay_with_rotated_image(image_np, mask):
    rotated_img = rotate_clockwise_90(image_np)
    color_map = np.array([
        [0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 0.0, 0.0],
    ])
    color_mask = color_map[mask]
    return 0.5 * rotated_img + 0.5 * color_mask

# Visualize a few random samples
os.makedirs(os.path.join(save_dir, 'vis_results'), exist_ok=True)
sample_indices = random.sample(range(len(train_dataset)), 10)
for vis_idx, idx in enumerate(sample_indices):
    img_embed, gt, box, pts, lbls, image, class_label = train_dataset[idx]
    pred = predict_overlay(img_embed, gt, box, pts, lbls)
    image_np = image.cpu().numpy()
    if image_np.shape[0] == 3:
        image_np = image_np.transpose(1, 2, 0)
    image_np = np.clip(image_np, 0, 1)
    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    axs[0].imshow(rotate_clockwise_90(image_np))
    axs[0].set_title("Original (Rotated)")
    axs[1].imshow(color_mask(gt.squeeze().numpy()))
    axs[1].set_title("Ground Truth")
    axs[2].imshow(pred, cmap='gray', vmin=0, vmax=2)
    axs[2].set_title("Prediction")
    axs[3].imshow(overlay_with_rotated_image(image_np, pred))
    axs[3].set_title("Overlay (Image Rotated)")
    for ax in axs:
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'vis_results', f'vis_{vis_idx}.png'))
    plt.close()
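After training, the saved weights can be loaded back into a fresh SAM model and used through the library's SamPredictor. A minimal sketch; the checkpoint path, the run-folder name, the test image path, and the box prompt are only examples:

import cv2
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor

device = 'cuda:0'
sam = sam_model_registry['vit_h'](checkpoint='models/sam_vit_h_4b8939.pth')
# Load the fine-tuned weights saved by the training script (example run folder).
sam.load_state_dict(torch.load('ultimate/train_20250101_000000/best_model.pth', map_location=device))
sam.to(device).eval()

predictor = SamPredictor(sam)
image = cv2.cvtColor(cv2.imread('images/test.png'), cv2.COLOR_BGR2RGB)  # example image
predictor.set_image(image)

masks, scores, _ = predictor.predict(
    box=np.array([100, 100, 400, 400]),  # example box prompt in xyxy pixel coordinates
    multimask_output=False,
)
print(masks.shape, scores)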