TransUnet官方代码训练自己数据集(彩色RGB3通道图像的分割)

该博客介绍了如何将数据集转换为TransUNet模型所需的格式,包括图像和标签的合并、预训练权重的下载、代码的修改等步骤,以适应五类汽车部件分割任务。详细讲解了数据集的结构、npz文件的生成、训练参数配置以及trainer.py的调整。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文章已经生成可运行项目,

***************************************************

码字不易,收藏之余,别忘了给我点个赞吧!

***************************************************

---------Start

官方代码:https://github.com/Beckschen/TransUNet

目的:训练5个类别的汽车部件分割任务(测试在另一篇博客中)

优快云数据集免费下载

实现效果:

在这里插入图片描述

在这里插入图片描述

1. github下载代码,并解压。

在这里插入图片描述
在这里插入图片描述

项目里的文件可能跟你下载的不一样,不急后面会讲到!

在这里插入图片描述

2. 配置数据集(尽最大努力还原官方数据集的格式)。

通常自己手上的数据集分images和labels文件夹,分别存放着原始图像和对应的mask图像,如下图所示; mask图像中的像素有0,1,2,3,4 分别代表背景,车身,轮子,车灯,窗户,一共五个类别,所以这里显示全黑色,肉眼看不出差别!通过阅读官方读取数据的代码,我们需要将一张图像和其对应的标签合并转化成一个.npz文件.

在这里插入图片描述
在这里插入图片描述
官方数据集格式,data文件夹,Synapse文件夹,test_vol_h5文件夹,train_npz文件夹手动创建!
在这里插入图片描述

转化数据集的代码如下,会将images中的图像和labels中的标签生成一个.npz文件。

def npz():
    #图像路径
    path = r'G:\dataset\car-segmentation\train\images\*.png'
    #项目中存放训练所用的npz文件路径
    path2 = r'G:\dataset\Unet\TransUnet-ori\data\Synapse\train_npz\\'
    for i,img_path in enumerate(glob.glob(path)):
    	#读入图像
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
        #读入标签
        label_path = img_path.replace('images','labels')
        label = cv2.imread(label_path,flags=0)
		#保存npz
        np.savez(path2+str(i),image=image,label=label)
        print('------------',i)

    # 加载npz文件
    # data = np.load(r'G:\dataset\Unet\Swin-Unet-ori\data\Synapse\train_npz\0.npz', allow_pickle=True)
    # image, label = data['image'], data['label']

    print('ok')

生成的文件在 data\Synapse\train_npz文件夹中,如下图,也可以自己定义生成的路径,然后把文件复制到data\Synapse\train_npz文件中。

在这里插入图片描述
data\Synapse\train_npz文件夹中存放的是训练集样本,按照同样的方式生成测试集样本,存放在data\Synapse\test_vol_h5文件夹中。
在这里插入图片描述
我的训练集203个样本,测试集3个样本。npz文件生成完成之后,找到train.txt和test_vol.txt,手动将文件里面的内容清空,split_data.py这个文件直接无视。自己写一个函数读取train_npz中所有的文件名称,然后将文件名称写入train.txt文件,一个名称一行,如下图所示。同理可完成test_vol.txt文件制作。
在这里插入图片描述

在这里插入图片描述
至此,数据集制作完毕!!!代码会先去train.txt文件中读取训练样本的名称,然后根据名称再去train_npz文件夹下读取npz文件。所以每一步都很重要,必须正确!

3. 下载预训练权重

官方下载地址

优快云下载地址[推荐]

进入网站后,点击imagenet21k文件夹。
在这里插入图片描述
下载这个权重文件即可。
在这里插入图片描述
手动创建如下多个文件夹,存放刚刚下载完毕的权重,注意名称跟我的保持一致!
在这里插入图片描述
至此,预训练权重已下载完毕。

4. 修改读取文件的方法

找到datasets/dataset_synapse.py文件中的Synapse_dataset类,修改__getitem__函数。

 def __getitem__(self, idx):
        if self.split == "train":
            slice_name = self.sample_list[idx].strip('\n')
            data_path = self.data_dir+"/"+slice_name+'.npz'
            data = np.load(data_path)
            image, label = data['image'], data['label']
        else:
            slice_name = self.sample_list[idx].strip('\n')
            data_path = self.data_dir+"/"+slice_name+'.npz'
            data = np.load(data_path)
            image, label = data['image'], data['label']
            image = torch.from_numpy(image.astype(np.float32))
            image = image.permute(2,0,1)
            label = torch.from_numpy(label.astype(np.float32))
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = self.sample_list[idx].strip('\n')
        return sample

找到datasets/dataset_synapse.py文件中的RandomGenerator类,修改__call__函数。

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        x, y,_ = image.shape
        if x != self.output_size[0] or y != self.output_size[1]:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y,1), order=3)  # why not 3?
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
        image = torch.from_numpy(image.astype(np.float32))
        image = image.permute(2,0,1)
        label = torch.from_numpy(label.astype(np.float32))
        sample = {'image': image, 'label': label.long()}
        return sample

至此,数据读取的部分已经修改完毕!

5. 配置训练参数

认真检查各个参数是否正确,这里的路径都是 ‘./’(当前目录下),不是"…/",训练时,batch_size通常大于1,我这里设置有误!类别数可根据你的任务定!

在这里插入图片描述
图片大小设置,越大越耗显存。
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

6. 修改trainer.py文件

设置trainer.py文件中的DataLoader函数中的num_workers=0
在这里插入图片描述

至此,所有代码修改完毕!

总结:以上修改内容针对彩色图像的分割任务, 由于仅文字表述某些操作存在局限性,故只能简略应答,有任何问题可下方留言评论。

本文章已经生成可运行项目
### 如何使用 TransUNet 训练自定义数据集 为了使用 TransUNet 模型训练自定义数据集,特别是针对医学图像分割任务,可以遵循以下方法: #### 数据准备 确保自定义数据集按照特定结构组织。对于医学图像分割任务,通常采用类似于 VOC 或 COCO 的格式。如果数据集尚未转换为此类标准格式,则需先完成此操作。 ```bash dataset/ ├── images/ │ ├── img1.png │ └── ... └── masks/ ├── mask1.png └── ... ``` 每个图像应有一个对应的掩码文件表示目标区域[^2]。 #### 安装依赖项 安装必要的 Python 库以支持 TransUNet 和其他辅助功能。这包括但不限于 PyTorch、TensorFlow 及其扩展包 torchvision/tensorflow_datasets 等。 ```bash pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 pip install tensorflow matplotlib scikit-image opencv-python git clone https://github.com/Beckschen/TransUNet.git cd TransUNet && pip install . ``` #### 配置环境变量与参数设置 创建配置文件 `config.yaml` 来指定超参数和其他重要选项,如批次大小、学习率等。 ```yaml # config.yaml example configuration file for training a TransUNet model. BATCH_SIZE: 8 LEARNING_RATE: 0.0001 EPOCHS: 50 IMAGE_SIZE: [256, 256] NUM_CLASSES: 2 DATASET_PATH: './path/to/dataset' MODEL_SAVE_DIR: './checkpoints/' LOGGING_LEVEL: 'INFO' ``` #### 编写训练脚本 编写一个完整的训练循环,在其中加载预处理后的数据并调用 TransUNet 进行迭代优化过程。 ```python import os from transunet import CONFIGS as configs from transunet.modeling import VisionTransformer, CONFIGS from data_loader import get_loader import torch.optim as optim import torch.nn.functional as F from tqdm import trange def train(): device = "cuda" if torch.cuda.is_available() else "cpu" # Load dataset loader and initialize the network architecture based on chosen settings from YAML dataloader = get_loader(config['DATASET_PATH'], batch_size=config['BATCH_SIZE']) config_vit = configs['ViT-B_16'] net = VisionTransformer(config_vit, num_classes=config['NUM_CLASSES']).to(device) optimizer = optim.Adam(net.parameters(), lr=float(config['LEARNING_RATE'])) criterion = F.cross_entropy epochs = int(config['EPOCHS']) with trange(epochs) as t: for epoch in t: running_loss = 0.0 for i, (inputs, labels) in enumerate(dataloader): inputs, labels = inputs.to(device), labels.to(device) outputs = net(inputs) loss = criterion(outputs, labels.long()) optimizer.zero_grad() loss.backward() optimizer.step() running_loss += loss.item() avg_loss = running_loss / len(dataloader) t.set_description(f'Epoch {epoch}, Loss={avg_loss:.4f}') if __name__ == '__main__': import yaml global config with open('config.yaml') as f: config = yaml.safe_load(f)['train'] train() ``` 上述代码展示了如何构建基于 PyTorch 实现的 TransUNet 模型训练流程,并通过 DataLoader 加载自定义的数据集来进行批量训练[^1]。
评论 569
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值