[pytorch] nnUNet for 2D Image Segmentation

A tutorial on how to use nnUNet for 2D image segmentation, using the MICCAI 2022 GOALS challenge as an example. This is currently my best-performing method.
paper: nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation
code: nnunet

Installation of nnUNet

PyTorch is required!
To use nnU-Net as an integrative framework (this will create a copy of the nnU-Net code on your computer so that you can modify it as needed):

git clone https://github.com/MIC-DKFZ/nnUNet.git
cd nnUNet
pip install -e .

Once these are installed, every nnU-Net operation is invoked from the command line with a command that starts with nnUNet_.
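If the installation succeeded, the nnU-Net v1 entry points used throughout this tutorial should all respond to -h, which is a quick way to verify the setup:

nnUNet_plan_and_preprocess -h
nnUNet_train -h
nnUNet_predict -h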

The next step is to create the data storage directory.

  1. Go into the nnUNet folder you created earlier and create a folder named DATASET; this is where we will put the data next.
  2. Go into the new DATASET folder and create three folders: nnUNet_preprocessed, nnUNet_raw, and nnUNet_trained_models. The first stores the preprocessed version of the original data, the second stores the original data you want to train on, and the third stores the training results.
  3. Go into the nnUNet_raw folder and create two folders: nnUNet_cropped_data and nnUNet_raw_data. The latter holds the original data, the former the cropped data.
  4. Go into nnUNet_raw_data and create a folder named Task888_GOALS. (Explanation: the nnU-Net data format is fixed. A task name such as Task001_BloodVessel consists of Task + ID + dataset name; you can choose the numeric ID freely, e.g. Task001_Heart for heart segmentation or Task002_Kidney for kidney segmentation, as long as you follow this format.) The whole tree can also be created in one go, as in the shell sketch below.
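A minimal shell sketch that creates the directory tree described above, assuming you start in the directory containing the cloned nnUNet repo (the imagesTr/imagesTs/labelsTr subfolders are used in the data configuration step below):

cd nnUNet
mkdir -p DATASET/nnUNet_preprocessed
mkdir -p DATASET/nnUNet_raw/nnUNet_cropped_data
mkdir -p DATASET/nnUNet_raw/nnUNet_raw_data/Task888_GOALS/imagesTr
mkdir -p DATASET/nnUNet_raw/nnUNet_raw_data/Task888_GOALS/imagesTs
mkdir -p DATASET/nnUNet_raw/nnUNet_raw_data/Task888_GOALS/labelsTr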

Environment configuration

nnU-Net needs to know where you intend to save raw data, preprocessed data, and trained models. For this you need to set a few environment variables.
Setting up Paths
Adding the export lines below to your .bashrc sets the paths permanently (until you delete those lines). If you only want to set them temporarily, run the export commands directly in your terminal:

export nnUNet_raw_data_base="/home/liyihao/LI/nnUNet/DATASET/nnUNet_raw"
export nnUNet_preprocessed="/home/liyihao/LI/nnUNet/DATASET/nnUNet_preprocessed"
export RESULTS_FOLDER="/home/liyihao/LI/nnUNet/DATASET/nnUNet_trained_models"
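To make them permanent, you can append the same lines to your ~/.bashrc and reload it (adjust the paths to your own directory layout):

echo 'export nnUNet_raw_data_base="/home/liyihao/LI/nnUNet/DATASET/nnUNet_raw"' >> ~/.bashrc
echo 'export nnUNet_preprocessed="/home/liyihao/LI/nnUNet/DATASET/nnUNet_preprocessed"' >> ~/.bashrc
echo 'export RESULTS_FOLDER="/home/liyihao/LI/nnUNet/DATASET/nnUNet_trained_models"' >> ~/.bashrc
source ~/.bashrc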

Data configuration

This is the form of the raw data.
We need to convert it to the format nnU-Net requires.
First, place the training images, the ground-truth labels, and the test images into the folders imagesTr, labelsTr, and imagesTs respectively, together with the dataset.json file (these file and folder names are fixed and cannot be changed).
The nnU-Net data format is fixed: Task888_GOALS consists of Task + ID + dataset name; imagesTr holds the training data, imagesTs the test data, and labelsTr the labels of the training data. A data sample such as la_003_0000.nii.gz consists of the case name + a modality flag + .nii.gz, where different modalities are distinguished by the suffixes 0000/0001/0002/0003.
Example tree structure:

nnUNet_raw_data_base/nnUNet_raw_data/Task002_Heart
├── dataset.json
├── imagesTr
│   ├── la_003_0000.nii.gz
│   ├── la_004_0000.nii.gz
│   ├── ...
├── imagesTs
│   ├── la_001_0000.nii.gz
│   ├── la_002_0000.nii.gz
│   ├── ...
└── labelsTr
    ├── la_003.nii.gz
    ├── la_004.nii.gz
    ├── ...

Our original 2D data is three-channel RGB. We can treat the three RGB channels as three modalities: extract each channel separately, reshape it to (1, height, width), and save it as a 3D volume.

import os
from tqdm import tqdm
import SimpleITK as sitk
import cv2
import numpy as np

root = '/home/liyihao/LI/GOALS'
base_image = root + '/GOALS2022-Train/Train/Image'
base_gt = root + '/GOALS2022-Train/Train/Layer_Masks'
base_test = root + '/GOALS2022-Validation/GOALS2022-Validation/Image'

target_labelsTr = '/home/liyihao/LI/nnUNet/DATASET/nnUNet_raw/nnUNet_raw_data/Task888_GOALS/labelsTr/'
target_imagesTr = '/home/liyihao/LI/nnUNet/DATASET/nnUNet_raw/nnUNet_raw_data/Task888_GOALS/imagesTr/'

# train set
savepath_img = target_imagesTr
savepath_mask = target_labelsTr

img_path = base_image
mask_path = base_gt
ImgList = os.listdir(img_path)
print(ImgList)

with tqdm(ImgList, desc="convert") as pbar:
    for name in pbar:
        Img = cv2.imread(os.path.join(img_path, name))
        gt_img = cv2.imread(os.path.join(mask_path, name))
        # remap the GOALS gray values to the consecutive label indices
        # declared in dataset.json: 0 -> 3 (RNFL), 80 -> 1 (GCIPL),
        # 160 -> 2 (choroid), 255 -> 0 (background)
        gt_img[gt_img == 0] = 3
        gt_img[gt_img == 80] = 1
        gt_img[gt_img == 160] = 2
        gt_img[gt_img == 255] = 0
        gt_img = gt_img[:, :, 1].astype(np.uint8)  # all three channels are identical, keep one
        # treat the RGB channels as three modalities: (H, W, 3) -> three (1, H, W) volumes
        Img_Transposed = np.transpose(Img, (2, 0, 1))
        Img_0 = Img_Transposed[0].reshape(1, Img_Transposed[0].shape[0], Img_Transposed[0].shape[1])
        Img_1 = Img_Transposed[1].reshape(1, Img_Transposed[1].shape[0], Img_Transposed[1].shape[1])
        Img_2 = Img_Transposed[2].reshape(1, Img_Transposed[2].shape[0], Img_Transposed[2].shape[1])
        gt_img = gt_img.reshape(1, gt_img.shape[0], gt_img.shape[1])
        # file names follow the nnU-Net convention: case name + modality flag + .nii.gz
        Img_0_name = 'GOALS_' + str(name.split('.')[0]) + '_0000.nii.gz'
        Img_1_name = 'GOALS_' + str(name.split('.')[0]) + '_0001.nii.gz'
        Img_2_name = 'GOALS_' + str(name.split('.')[0]) + '_0002.nii.gz'
        gt_img_name = 'GOALS_' + str(name.split('.')[0]) + '.nii.gz'
        
        Img_0_nii = sitk.GetImageFromArray(Img_0)
        Img_1_nii = sitk.GetImageFromArray(Img_1)
        Img_2_nii = sitk.GetImageFromArray(Img_2)
        gt_img_nii = sitk.GetImageFromArray(gt_img)
        
        sitk.WriteImage(Img_0_nii, os.path.join(savepath_img, Img_0_name))
        sitk.WriteImage(Img_1_nii, os.path.join(savepath_img, Img_1_name))
        sitk.WriteImage(Img_2_nii, os.path.join(savepath_img, Img_2_name))
        sitk.WriteImage(gt_img_nii, os.path.join(savepath_mask, gt_img_name))

# test
img_path = base_test
ImgList = os.listdir(img_path)
print(ImgList)

savepath_img = '/home/liyihao/LI/nnUNet/DATASET/nnUNet_raw/nnUNet_raw_data/Task888_GOALS/imagesTs/'
with tqdm(ImgList, desc="convert") as pbar:
    for name in pbar:
        Img = cv2.imread(os.path.join(img_path, name))
        # same channel splitting as for the training images
        Img_Transposed = np.transpose(Img, (2, 0, 1))
        Img_0 = Img_Transposed[0].reshape(1, Img_Transposed[0].shape[0], Img_Transposed[0].shape[1])
        Img_1 = Img_Transposed[1].reshape(1, Img_Transposed[1].shape[0], Img_Transposed[1].shape[1])
        Img_2 = Img_Transposed[2].reshape(1, Img_Transposed[2].shape[0], Img_Transposed[2].shape[1])
        Img_0_name = 'GOALS_' + str(name.split('.')[0]) + '_0000.nii.gz'
        Img_1_name = 'GOALS_' + str(name.split('.')[0]) + '_0001.nii.gz'
        Img_2_name = 'GOALS_' + str(name.split('.')[0]) + '_0002.nii.gz'

        
        Img_0_nii = sitk.GetImageFromArray(Img_0)
        Img_1_nii = sitk.GetImageFromArray(Img_1)
        Img_2_nii = sitk.GetImageFromArray(Img_2)
        
        sitk.WriteImage(Img_0_nii, os.path.join(savepath_img, Img_0_name))
        sitk.WriteImage(Img_1_nii, os.path.join(savepath_img, Img_1_name))
        sitk.WriteImage(Img_2_nii, os.path.join(savepath_img, Img_2_name))
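As a quick sanity check (an optional sketch, assuming the conversion above has been run), read one written label volume back and confirm the shape is (1, H, W) and the label values are exactly {0, 1, 2, 3}:

import os
import SimpleITK as sitk
import numpy as np

labelsTr = '/home/liyihao/LI/nnUNet/DATASET/nnUNet_raw/nnUNet_raw_data/Task888_GOALS/labelsTr/'
sample = sorted(os.listdir(labelsTr))[0]
vol = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(labelsTr, sample)))
print(sample, vol.shape)  # expected shape: (1, H, W)
print(np.unique(vol))     # expected labels: [0 1 2 3]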


Make the dataset.json file:

import os
import json
from collections import OrderedDict

train_list = os.listdir('/home/liyihao/LI/nnUNet/DATASET/nnUNet_raw/nnUNet_raw_data/Task888_GOALS/labelsTr/')
print(train_list)
test_list = os.listdir('/home/liyihao/LI/GOALS/GOALS2022-Validation/GOALS2022-Validation/Image')
test_patient = []
for i in test_list:
    patient = i.split('.')[0]
    name = 'GOALS_'+str(patient)+ '.nii.gz'
    test_patient.append(name)
print(test_patient)

json_dict = OrderedDict()
json_dict['name'] = "GOALS"
json_dict['description'] = "LI Yihao copyright"
json_dict['tensorImageSize'] = "3D"
json_dict['reference'] = "see GOALS2022"
json_dict['licence'] = "see GOALS2022"
json_dict['release'] = "0.0"
json_dict['modality'] = {
    "0": "R",
    "1": "G",
    "2": "B"
}
json_dict['labels'] = {
    "0": "background",
    "1": "GCIPL",
    "2": "CHOROID",
    "3": "RNFL"
}

json_dict['numTraining'] = len(train_list)
json_dict['numTest'] = len(test_patient)
json_dict['training'] = [{'image': "./imagesTr/%s" % i, "label": "./labelsTr/%s" % i} for i in
                         train_list]
json_dict['test'] = [{'image': "./imagesTs/%s" % i} for i in test_patient]

with open('dataset.json', 'w', encoding='utf-8') as f:
    json.dump(json_dict, f, ensure_ascii=False, indent=4)
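Note that the script writes dataset.json to the current working directory; it must end up inside the task folder, next to imagesTr, imagesTs, and labelsTr:

mv dataset.json /home/liyihao/LI/nnUNet/DATASET/nnUNet_raw/nnUNet_raw_data/Task888_GOALS/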

I put all the processed data in:

/home/shared/GOALS_challenge/LI/nnUnet

Train

Data preprocessing
Note that since this is a 2D dataset, there is no need to run preprocessing for 3D U-Nets. You should therefore run the nnUNet_plan_and_preprocess command like this (ref: Task120_Massachusetts_RoadSegm.py):

nnUNet_plan_and_preprocess -t 888 -pl3d None
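On the first run it can be worth adding nnU-Net's integrity check, which verifies that the file names and dataset.json are consistent before preprocessing:

nnUNet_plan_and_preprocess -t 888 -pl3d None --verify_dataset_integrity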

Training needs more than 11 GB of GPU memory. Run the following commands one by one; each fold of the 5-fold cross-validation takes 16+ hours.

nnUNet_train 2d nnUNetTrainerV2 Task888_GOALS 0 --npz
nnUNet_train 2d nnUNetTrainerV2 Task888_GOALS 1 --npz
nnUNet_train 2d nnUNetTrainerV2 Task888_GOALS 2 --npz
nnUNet_train 2d nnUNetTrainerV2 Task888_GOALS 3 --npz
nnUNet_train 2d nnUNetTrainerV2 Task888_GOALS 4 --npz
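If a training run is interrupted, nnU-Net v1 can resume a fold from its latest checkpoint with the -c flag (check nnUNet_train -h for the exact option on your version), e.g.:

nnUNet_train 2d nnUNetTrainerV2 Task888_GOALS 0 --npz -c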

Inference

After running the 5-fold cross-validation, the best configuration can be determined. In the command below, 888 is the task ID.

nnUNet_find_best_configuration -m 2d -t 888 

This generates a txt file; open it to find the recommended inference command:

nnUNet_predict -i FOLDER_WITH_TEST_CASES -o OUTPUT_FOLDER_MODEL1 -tr nnUNetTrainerV2 -ctr nnUNetTrainerV2CascadeFullRes -m 2d -p nnUNetPlansv2.1 -t Task888_GOALS

In my case:

nnUNet_predict -i /home/liyihao/LI/nnUNet/DATASET/nnUNet_raw/nnUNet_raw_data/Task888_GOALS/imagesTs/ -o ./pre_nnunet/ -tr nnUNetTrainerV2 -ctr nnUNetTrainerV2CascadeFullRes -m 2d -p nnUNetPlansv2.1 -t Task888_GOALS

Because the model predictions are .nii.gz files, they need to be converted back into 2D images:

import os
from tqdm import tqdm
import SimpleITK as sitk
import cv2
import numpy as np

img_dir = './pre_nnunet/'
img_list = [i for i in os.listdir(img_dir) if ".nii.gz" in i]
print(img_list)

os.makedirs('output', exist_ok=True)  # cv2.imwrite fails silently if the folder does not exist

with tqdm(img_list, desc="convert") as pbar:
    for name in pbar:
        image = sitk.ReadImage(os.path.join(img_dir, name))
        image = sitk.GetArrayFromImage(image)[0]  # drop the leading axis: (1, H, W) -> (H, W)
        # invert the label mapping used during data preparation:
        # 0 (background) -> 255, 1 (GCIPL) -> 80, 2 (choroid) -> 160, 3 (RNFL) -> 0
        image[image == 0] = 255
        image[image == 1] = 80
        image[image == 2] = 160
        image[image == 3] = 0
        # recover the original case name: GOALS_<case>.nii.gz -> <case>.png
        cv2.imwrite(os.path.join('output', name.split(".")[0].split("_")[1] + ".png"), image)