Training SAM on Your Own Dataset

I. Dataset Preprocessing

1. Dataset Preparation

The training data can come from two sources: your own dataset or a public dataset.

(1) Your own dataset

Annotating your own segmentation dataset:

Use labelme (which supports model-assisted annotation) for labeling. The labeled data consists of an images folder (the segmentation images) and a labels folder (the corresponding .json annotation files).

For the detailed annotation workflow, see the author's other article:

https://blog.youkuaiyun.com/sunshineine/article/details/147231877?spm=1001.2014.3001.5502

Data conversion (turning the .json labels into pseudo-color images):

Use labelme_export_json, a command-line tool shipped with labelme, to convert each .json label into a pseudo-color image, then collect the label.png from every per-image output folder into a single common folder (a batch-conversion sketch follows the article link below).

For the detailed conversion workflow, see the author's other article:

https://blog.youkuaiyun.com/sunshineine/article/details/147422668
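
As a rough batch-conversion sketch (the folder names labels and labels_png are placeholders; it assumes labelme's labelme_export_json tool is on the PATH):

import os
import shutil
import subprocess

json_dir, out_dir = 'labels', 'labels_png'  # .json labels in; flat label.png files out
os.makedirs(out_dir, exist_ok=True)

for name in sorted(os.listdir(json_dir)):
    if not name.endswith('.json'):
        continue
    stem = os.path.splitext(name)[0]
    tmp = os.path.join(out_dir, stem + '_json')
    # labelme_export_json writes img.png, label.png and label_viz.png into the -o directory
    subprocess.run(['labelme_export_json', os.path.join(json_dir, name), '-o', tmp], check=True)
    # keep only the pseudo-color label, renamed after its image, in one flat folder
    shutil.copy(os.path.join(tmp, 'label.png'), os.path.join(out_dir, stem + '.png'))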

(2) Public datasets

2. Packaging the "image + GT + embedding" data from the segmentation dataset

To avoid paying the image encoder's (e.g. ViT-H's) per-image embedding cost over and over, precompute the embeddings and package the "image + GT + embedding" triples from the segmentation dataset ahead of time; this makes the later training, testing, and inference stages considerably faster (a sketch of the precomputation follows the article link below).

For the detailed packaging workflow, see the author's other article:

https://blog.youkuaiyun.com/sunshineine/article/details/147465912
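
For illustration only (not the linked article's exact script), the sketch below precomputes the ViT-H embeddings and bundles them with the images and ground truth into one .npz file. The paths and the array key names (imgs, gts, img_embeddings) are assumptions, and all images are assumed to share one size:

import os
import numpy as np
import torch
from PIL import Image
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide

device = 'cuda:0'
sam = sam_model_registry['vit_h'](checkpoint='models/sam_vit_h_4b8939.pth').to(device)
sam.eval()
transform = ResizeLongestSide(sam.image_encoder.img_size)  # longest side -> 1024

imgs, gts, embeds = [], [], []
with torch.no_grad():
    for name in sorted(os.listdir('data/images')):
        stem = os.path.splitext(name)[0]
        image = np.array(Image.open(os.path.join('data/images', name)).convert('RGB'))
        # labelme's label.png is a palette PNG whose pixel values are the class ids
        gt = np.array(Image.open(os.path.join('data/labels', stem + '.png')))
        x = transform.apply_image(image)                                  # resize
        x = torch.as_tensor(x, device=device).permute(2, 0, 1)[None].float()
        x = sam.preprocess(x)                                             # normalize + pad to 1024x1024
        emb = sam.image_encoder(x)                                        # 1 x 256 x 64 x 64
        imgs.append(image)
        gts.append(gt)
        embeds.append(emb.cpu().numpy().squeeze())

np.savez_compressed('data/data.npz', imgs=np.stack(imgs), gts=np.stack(gts),
                    img_embeddings=np.stack(embeds))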

3. Writing the Dataset class

To handle the custom .npz data format and the task's requirements flexibly, write your own Dataset class that inherits from PyTorch's Dataset (a sketch of the expected interface follows the article link below).

For the detailed implementation, see the author's other article:

https://blog.youkuaiyun.com/sunshineine/article/details/147472085?spm=1001.2014.3001.5502
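
For orientation only (the real class is developed in the linked article), here is a rough sketch of the seven-item interface that the training script in Section III expects, assuming the .npz keys from the packaging sketch above and prompts derived from the ground truth; the GT is assumed to be stored at the mask decoder's 256x256 output resolution so the loss shapes match:

import numpy as np
import torch
from torch.utils.data import Dataset

class MyselfDataset(Dataset):
    def __init__(self, npz_path):
        data = np.load(npz_path)
        self.imgs = data['imgs']              # N x H x W x 3, uint8
        self.gts = data['gts']                # N x H x W, integer class ids
        self.embeds = data['img_embeddings']  # N x 256 x 64 x 64

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        gt = self.gts[idx]
        ys, xs = np.where(gt > 0)             # assumes every sample has foreground
        # one box prompt around the foreground plus one positive point at its centroid
        box = np.array([xs.min(), ys.min(), xs.max(), ys.max()], dtype=np.float32)
        pts = np.array([[xs.mean(), ys.mean()]], dtype=np.float32)
        lbls = np.array([1], dtype=np.int32)
        class_label = int(gt.max())
        return (torch.from_numpy(self.embeds[idx]).float(),
                torch.from_numpy(gt[None].astype(np.int64)),  # 1 x H x W
                torch.from_numpy(box),
                torch.from_numpy(pts),
                torch.from_numpy(lbls),
                torch.from_numpy(self.imgs[idx]).float() / 255.0,
                class_label)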

II. Creating the SAM Project

1. Uploading the code to the server

Download the SAM source code from GitHub to your local machine, then transfer it to the server with Xftp 8.
https://github.com/facebookresearch/segment-anything

2. Environment Setup

The steps that follow are carried out on the server. Create a virtual environment named sam. Training SAM requires properly configured CUDA, cuDNN, torch, torchvision, and torchaudio; the versions used here are:

python=3.8, cuda=11.8.0, cudnn=9.3.0.75, torch=2.0.0+cu118, torchaudio=2.0.0+cu118, torchvision=0.15.1+cu118

The full conda list output of the environment:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
brotli-python             1.1.0            py38h17151c0_1    conda-forge
bzip2                     1.0.8                h4bc722e_7    conda-forge
ca-certificates           2025.1.31            hbcca054_0    conda-forge
certifi                   2024.8.30          pyhd8ed1ab_0    conda-forge
charset-normalizer        3.4.0              pyhd8ed1ab_0    conda-forge
cmake                     4.0.0                    pypi_0    pypi
coloredlogs               15.0.1                   pypi_0    pypi
contourpy                 1.1.1                    pypi_0    pypi
cuda-version              11.8                 h70ddcb2_3    conda-forge
cudatoolkit               11.8.0              h4ba93d1_13    conda-forge
cudnn                     9.3.0.75             hc149ed2_0    <unknown>
cycler                    0.12.1                   pypi_0    pypi
filelock                  3.16.1             pyhd8ed1ab_0    conda-forge
flatbuffers               25.2.10                  pypi_0    pypi
fonttools                 4.56.0                   pypi_0    pypi
freetype                  2.13.3               h48d6fc4_0    conda-forge
humanfriendly             10.0                     pypi_0    pypi
idna                      3.10               pyhd8ed1ab_0    conda-forge
importlib-resources       6.4.5                    pypi_0    pypi
jinja2                    3.1.4              pyhd8ed1ab_0    conda-forge
kernel-headers_linux-64   3.10.0              he073ed8_18    conda-forge
kiwisolver                1.4.7                    pypi_0    pypi
lcms2                     2.15                 h7f713cb_2    conda-forge
ld_impl_linux-64          2.43                 h712a8e2_4    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libabseil                 20230802.1      cxx17_h59595ed_0    conda-forge
libblas                   3.9.0           31_h59b9bed_openblas    conda-forge
libcblas                  3.9.0           31_he106b2a_openblas    conda-forge
libdeflate                1.19                 hd590300_0    conda-forge
libffi                    3.4.6                h2dba641_0    conda-forge
libgcc                    14.2.0               h767d61c_2    conda-forge
libgcc-ng                 14.2.0               h69a702a_2    conda-forge
libgfortran               14.2.0               h69a702a_2    conda-forge
libgfortran5              14.2.0               hf1ad2bd_2    conda-forge
libgomp                   14.2.0               h767d61c_2    conda-forge
libjpeg-turbo             2.1.5.1              hd590300_1    conda-forge
liblapack                 3.9.0           31_h7ac8fdf_openblas    conda-forge
liblzma                   5.6.4                hb9d3cd8_0    conda-forge
liblzma-devel             5.6.4                hb9d3cd8_0    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libopenblas               0.3.29          pthreads_h94d23a6_0    conda-forge
libpng                    1.6.47               h943b412_0    conda-forge
libprotobuf               4.24.4               hf27288f_0    conda-forge
libsqlite                 3.49.1               hee588c1_2    conda-forge
libstdcxx                 14.2.0               h8f9b012_2    conda-forge
libstdcxx-ng              14.2.0               h4852527_2    conda-forge
libtiff                   4.6.0                h29866fb_1    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libuv                     1.50.0               hb9d3cd8_0    conda-forge
libwebp-base              1.5.0                h851e524_0    conda-forge
libxcb                    1.15                 h0b41bf4_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libzlib                   1.3.1                hb9d3cd8_2    conda-forge
lit                       18.1.8                   pypi_0    pypi
markupsafe                2.1.5            py38h01eb140_0    conda-forge
matplotlib                3.7.5                    pypi_0    pypi
monai                     1.3.2                    pypi_0    pypi
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
ncurses                   6.5                  h2d0b736_3    conda-forge
networkx                  3.1                pyhd8ed1ab_0    conda-forge
nomkl                     1.0                  h5ca1d4c_0    conda-forge
numpy                     1.24.4           py38h59b608b_0    conda-forge
onnx                      1.17.0                   pypi_0    pypi
onnxruntime               1.19.2                   pypi_0    pypi
opencv-python             4.11.0.86                pypi_0    pypi
openjpeg                  2.5.2                h488ebb8_0    conda-forge
openssl                   3.4.1                h7b32b05_0    conda-forge
packaging                 24.2                     pypi_0    pypi
pillow                    10.0.1           py38h71741d6_1    conda-forge
pip                       24.3.1             pyh8b19718_0    conda-forge
protobuf                  5.29.4                   pypi_0    pypi
pthread-stubs             0.4               hb9d3cd8_1002    conda-forge
pycocotools               2.0.7                    pypi_0    pypi
pyparsing                 3.1.4                    pypi_0    pypi
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.8.20          h4a871b0_2_cpython    conda-forge
python-dateutil           2.9.0.post0              pypi_0    pypi
python_abi                3.8                      5_cp38    conda-forge
readline                  8.2                  h8c095d6_2    conda-forge
requests                  2.32.3             pyhd8ed1ab_0    conda-forge
segment-anything          1.0                       dev_0    <develop>
setuptools                75.3.0             pyhd8ed1ab_0    conda-forge
six                       1.17.0                   pypi_0    pypi
sleef                     3.8                  h1b44611_0    conda-forge
sympy                     1.13.3             pyh04b8f61_4    conda-forge
sysroot_linux-64          2.17                h0157908_18    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
torch                     2.0.0+cu118              pypi_0    pypi
torchaudio                2.0.0+cu118              pypi_0    pypi
torchvision               0.15.1+cu118             pypi_0    pypi
tqdm                      4.67.1                   pypi_0    pypi
triton                    2.0.0                    pypi_0    pypi
typing_extensions         4.12.2             pyha770c72_0    conda-forge
tzdata                    2025b                h78e105d_0    conda-forge
urllib3                   2.2.2              pyhd8ed1ab_0    conda-forge
wheel                     0.45.1             pyhd8ed1ab_0    conda-forge
xorg-libxau               1.0.12               hb9d3cd8_0    conda-forge
xorg-libxdmcp             1.1.5                hb9d3cd8_0    conda-forge
xz                        5.6.4                hbcc6ac9_0    conda-forge
xz-gpl-tools              5.6.4                hbcc6ac9_0    conda-forge
xz-tools                  5.6.4                hb9d3cd8_0    conda-forge
zipp                      3.20.2                   pypi_0    pypi
zstd                      1.5.7                hb8e6e7a_2    conda-forge

Sometimes a rented server does not come with a matching environment, in which case you have to configure the virtual environment yourself (a command sketch follows the article link below).

For the detailed setup workflow, see the author's other article:

https://mpbeta.youkuaiyun.com/mp_blog/creation/editor/146763545
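
As a rough sketch (the conda-forge channel and the PyTorch wheel index are common choices, not prescriptions from the linked article), the environment above can be reproduced roughly as follows:

conda create -n sam python=3.8
conda activate sam
conda install -c conda-forge cudatoolkit=11.8 cudnn
pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.0 --index-url https://download.pytorch.org/whl/cu118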

3. Installing the SAM Package

Activate the sam virtual environment on the server:

conda activate sam

Change into the directory of the SAM project uploaded via Xftp 8:

cd sam-main

Install the package in editable mode:

pip install -e .

Install the optional dependencies:

pip install opencv-python pycocotools matplotlib onnxruntime onnx
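
A quick sanity check that the CUDA build of torch is active (this should print 2.0.0+cu118 and True):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"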

4. Project Layout

assets: resource files
data: your own dataset, including the training and validation sets
images: images used for prediction
models: trained models (weight files); the pretrained checkpoint download is shown below
notebook: the official example notebooks
scripts: scripts for automatic mask generation (amg.py) and ONNX model export (export_onnx_model.py); example invocations are shown below
segment_anything: the core of the codebase, including:
        modeling: the three main components, i.e. the image encoder, prompt encoder, and mask decoder
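
The models folder is expected to hold the pretrained weights used in the training script below (sam_vit_h_4b8939.pth); the download link is the one from the official repo README:

wget -P models https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

For reference, the two scripts are invoked from the repo root roughly as follows (the input/output paths are placeholders):

python scripts/amg.py --checkpoint models/sam_vit_h_4b8939.pth --model-type vit_h --input images --output output
python scripts/export_onnx_model.py --checkpoint models/sam_vit_h_4b8939.pth --model-type vit_h --output sam_vit_h.onnx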

III. Model Training

The complete training script is shown below: it loads the precomputed embeddings, fine-tunes only SAM's mask decoder with a combined Dice + cross-entropy loss, tracks the loss and pixel accuracy per epoch, keeps the latest and best checkpoints, and finally plots the training curves and a few qualitative predictions.

import os
import torch
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import monai
from datetime import datetime
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from data.npzdataset import MyselfDataset
import random

# ================== Configuration ==================
npz_tr_path = 'D:/1/SAM/sam-main/data/data.npz'
model_type = 'vit_h'
checkpoint = 'D:/1/SAM/segment-anything-main/models/sam_vit_h_4b8939.pth'
model_save_root = './ultimate'
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = os.path.join(model_save_root, f"train_{timestamp}")
os.makedirs(save_dir, exist_ok=True)
os.makedirs(os.path.join(save_dir, 'vis_results'), exist_ok=True)
device = 'cuda:0'
num_epochs = 100
batch_size = 4
num_classes = 3  # 0: background, 1: rail, 2: obstacle

# ================== Model & Optimizer ==================
sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
sam_model.train()
# Only the mask decoder is fine-tuned; the image embeddings were precomputed offline
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=1e-5)
# Combined Dice + cross-entropy loss; to_onehot_y one-hot-encodes the integer GT
seg_loss_dice = monai.losses.DiceCELoss(
    to_onehot_y=True, softmax=True, include_background=True,
    sigmoid=False, squared_pred=True, reduction='mean'
)

# ================== Data Loading ==================
train_dataset = MyselfDataset(npz_tr_path)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

losses, accuracies = [], []
best_loss = float('inf')

# ================== Training Loop ==================
for epoch in range(num_epochs):
    epoch_loss, correct_pixels, total_pixels = 0, 0, 0

    for step, batch in enumerate(tqdm(train_dataloader)):
        img_embed, gt, box, pts, lbls, image, class_label = batch

        img_embed = img_embed.view(img_embed.shape[0], 256, 64, 64).to(device)
        gt = gt.to(device)

        def forward_class(gt, box, pts, lbls):
            sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
            box_np = box.numpy()
            box_trans = sam_trans.apply_boxes(box_np, (gt.shape[-2], gt.shape[-1]))
            box_torch = torch.as_tensor(box_trans, dtype=torch.float, device=device)
            pts_torch = torch.as_tensor(pts, dtype=torch.float, device=device)
            lbls_torch = torch.as_tensor(lbls, dtype=torch.int, device=device)
            pt = (pts_torch, lbls_torch)

            sparse, dense = sam_model.prompt_encoder(points=pt, boxes=box_torch, masks=None)
            mask_pred, _ = sam_model.mask_decoder(
                image_embeddings=img_embed,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse,
                dense_prompt_embeddings=dense,
                multimask_output=False,
            )
            # Tile the single-channel decoder logits across num_classes channels
            # so the shape matches what the multi-class loss expects
            mask_pred = mask_pred.repeat(1, num_classes, 1, 1)
            return mask_pred, gt

        mask_pred, gt = forward_class(gt, box, pts, lbls)

        loss = seg_loss_dice(mask_pred, gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        with torch.no_grad():
            pred = torch.argmax(mask_pred, dim=1)
            correct_pixels += (pred == gt.squeeze(1)).sum().item()
            total_pixels += torch.numel(gt)

    epoch_loss /= (step + 1)
    accuracy = correct_pixels / total_pixels
    losses.append(epoch_loss)
    accuracies.append(accuracy)

    print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {epoch_loss:.6f}, Accuracy: {accuracy:.4f}")

    torch.save(sam_model.state_dict(), os.path.join(save_dir, "latest_model.pth"))
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(sam_model.state_dict(), os.path.join(save_dir, "best_model.pth"))

# ================== Training Curves ==================
def smooth_curve(points, factor=0.8):
    smoothed = []
    for point in points:
        if smoothed:
            smoothed.append(smoothed[-1] * factor + point * (1 - factor))
        else:
            smoothed.append(point)
    return smoothed

plt.figure(figsize=(8, 5))
plt.plot(range(1, num_epochs + 1), smooth_curve(losses), label='Smoothed Loss')
plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title('Training Loss'); plt.grid(True); plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(save_dir, 'training_loss.png'))
plt.close()

plt.figure(figsize=(8, 5))
plt.plot(range(1, num_epochs + 1), smooth_curve(accuracies), label='Smoothed Accuracy', color='orange')
plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title('Training Accuracy'); plt.grid(True); plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(save_dir, 'training_accuracy.png'))
plt.close()

# ================== Prediction Visualization ==================
print("Loading best model for visualization...")
sam_model.load_state_dict(torch.load(os.path.join(save_dir, "best_model.pth")))
sam_model.eval()

# Colorize the ground-truth mask
def color_mask(gt):
    color_map = np.array([
        [0.0, 0.0, 0.0],  # class 0
        [0.0, 1.0, 0.0],  # class 1
        [1.0, 0.0, 0.0],  # class 2
    ])
    return color_map[gt]

# Run one sample through the prompt encoder and mask decoder, return the predicted class map
def predict_overlay(img_embed, gt, box, pts, lbls):
    img_embed = img_embed.view(1, 256, 64, 64).to(device)
    gt = gt.unsqueeze(0).to(device)
    sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
    box = sam_trans.apply_boxes(box.numpy(), (gt.shape[-2], gt.shape[-1]))
    box_torch = torch.as_tensor(box, dtype=torch.float, device=device).unsqueeze(0)
    coords_torch = torch.as_tensor(pts, dtype=torch.float, device=device).unsqueeze(0)
    labels_torch = torch.as_tensor(lbls, dtype=torch.int, device=device).unsqueeze(0)
    pt = (coords_torch, labels_torch)

    with torch.no_grad():
        sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
            points=pt, boxes=box_torch, masks=None
        )
        pred, _ = sam_model.mask_decoder(
            image_embeddings=img_embed,
            image_pe=sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        pred = pred.repeat(1, num_classes, 1, 1)
        pred = torch.argmax(pred, dim=1).cpu().squeeze().numpy()
    return pred

# Rotate 90 degrees clockwise (three 90-degree counter-clockwise rotations)
def rotate_clockwise_90(img):
    return np.rot90(img, k=3)

# Blend the rotated image with the colorized mask
def overlay_with_rotated_image(image_np, mask):
    rotated_img = rotate_clockwise_90(image_np)
    color_map = np.array([
        [0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 0.0, 0.0],
    ])
    color_mask = color_map[mask]
    return 0.5 * rotated_img + 0.5 * color_mask

# Visualize 10 random training samples
os.makedirs(os.path.join(save_dir, 'vis_results'), exist_ok=True)

sample_indices = random.sample(range(len(train_dataset)), 10)
for vis_idx, idx in enumerate(sample_indices):
    img_embed, gt, box, pts, lbls, image, class_label = train_dataset[idx]
    pred = predict_overlay(img_embed, gt, box, pts, lbls)

    image_np = image.cpu().numpy()
    if image_np.shape[0] == 3:
        image_np = image_np.transpose(1, 2, 0)
    image_np = np.clip(image_np, 0, 1)

    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    axs[0].imshow(rotate_clockwise_90(image_np))
    axs[0].set_title("Original (Rotated)")

    axs[1].imshow(color_mask(gt.squeeze().numpy()))
    axs[1].set_title("Ground Truth")

    axs[2].imshow(pred, cmap='gray', vmin=0, vmax=2)
    axs[2].set_title("Prediction")

    axs[3].imshow(overlay_with_rotated_image(image_np, pred))
    axs[3].set_title("Overlay (Image Rotated)")

    for ax in axs:
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'vis_results', f'vis_{vis_idx}.png'))
    plt.close()
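
Assuming the script above is saved as train.py in the project root (the file name is your choice), training is launched with:

python train.py

The checkpoints (latest_model.pth, best_model.pth), the loss and accuracy curves, and the sample visualizations are all written to the timestamped directory under ./ultimate.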




        
