复现TransUnet官方代码用于训练自己数据集-超详细实操和debug记录

该文章已生成可运行项目,

一、制作自己的分割数据集

参考:安装打标签工具并打标签,在我的博客里自行检索即可
具体的标注软件使用的是labelme,安装过程和使用方法大家可以看链接给出的另一篇文章,我这里把主要流程说一下。

  1. 用labelme对数据进行标注,每张图片标注保存后会对应生成一个.json文件,如下图
    在这里插入图片描述
  2. 编写代码利用json文件生成标签图
  3. 最后将图像image和标签图label使用相同的命名保存到两个文件夹中
  4. 另存一份上述文件夹,并添加后缀deleted,即image_deleted和label_deleted,然后根据label_deleted逐个删除不需要的图片,再利用代码根据label_deleted中删除后的图片列表来删除image_deleted,从而生成一组匹配的图片和标签。

二、下载官方TransUnet代码并加载自己的数据

1. 首先下载官方代码并解压

代码地址如下:https://github.com/Beckschen/TransUNet

https://github.com/Beckschen/TransUNet

在这里插入图片描述

2. 配置Conda的pytroch环境

本次测试使用的环境是python 3.7.5 torch 1.9.0
在这里插入图片描述
程序文件架构如下:
在这里插入图片描述

3. 制备TransUnet所需数据

由于TransUnet源码所需的数据格式为npz,因此需要通过代码首先将数据转换为npz格式。
在根目录下创建prepareddata文件夹,并在其中创建image_data3_crop_deleted和label_data3_crop_deleted文件夹、以及图像转化工具getnpz_test.py和getnpz_train.py 还有列表生成工具write_test_txt.py和write_train_txt.py

首先,转换代码会把image和label成对生成在一个npz中,并将npz保存在根目录的 data/Synapse/test_vol_h5/文件夹下。

‘…/data/Synapse/test_vol_h5/’
这里,我们默认通过getnpz_test.py和getnpz_train.py对全部的样本生成到Synapse文件夹下,不进行训练集和测试集的划分操作。 交给后面的txt列表生成。

测试集npz生成代码如下:

import glob
import numpy as np
import cv2
import logging
import pdb
import logging
import os
import sys
from PIL import Image
import torchvision.transforms as transforms

def npz():

    #配置日志文件
    exp = 'test_dataname'
    log_folder = './datacreate_log'
    os.makedirs(log_folder, exist_ok=True)
    logging.basicConfig(filename=log_folder + '/'+exp+".txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))

    #原图像路径
    path = '/Users/XXX/Documents/codespace/TransUNet-main-xxx-data3/xxxpreparedata/image_data3_crop_deleted/*.png'

    #项目中存放训练所用的npz文件路径
    # path2 = '../data/Synapse/train_npz/'
    path2 = '../data/Synapse/test_vol_h5/'

    print(glob.glob(path))
    for i,img_path in enumerate(glob.glob(path)):
        #读入图像
        image = cv2.imread(img_path) # cv2读图片,读进来默认是BGR格式的
        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB) # BGR转RGB格式 三通道
        # image = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)

        #读入标签
        # print(img_path)
        # label_path = img_path.replace('images','labels')
        label_path = img_path.replace('image_data3_crop_deleted','label_data3_crop_deleted')
        # flags: 标志位, 表示读取数据的格式,读取彩色可以设为cv2.IMREAD_COLOR(flags=1)
        # ,读取灰度图像设为cv2.IMREAD_GRAYSCALE(flags=0),读取原始图像设为cv2.IMREAD_UNCHANGED(flags=-1)。
        label = cv2.imread(label_path,flags=0) # 读取灰度图像(0-255)

        # pdb.set_trace()
        #将非目标像素设置为0
        label[label!=38]=0 #根据你自己的数据修改,这里我只有labelme标注的一类
        #将目标像素设置为1
        label[label==38]=1
        #保存npz
        #image:(256, 256, 3)  label:(256,256)
        np.savez(path2+str(i),image=image,label=label) #三通道的图,单通道的标签
        # print('------------',i)
        logging.info('test'+img_path+'----'+str(i))
    # 加载npz文件
    # data = np.load(r'G:\dataset\Unet\Swin-Unet-ori\data\Synapse\train_npz\0.npz', allow_pickle=True)
    # image, label = data['image'], data['label']

    print('ok')

npz()

报错:

FileNotFoundError: [Errno 2] No such file or directory: '../data/Synapse/test_vol_h5/0.npz'

说明我们需要在Synapse下手动创建data/test_vol_h5这个文件夹
再次运行结束后,显示一个OK和所有生成图像的对应标签列表,则说明运行完成。其中,原图片文件和标签的名字会被修改后存入npz,因此,我们设计了一个记录器来存储npz文件和标签文件的对应关系。 文件位置在datacreate_log/train_dataname.txt

然后是运行getnpz_train.py,

import glob
import numpy as np
import cv2
import logging
import pdb
import logging
import os
import sys
from PIL import Image
import torchvision.transforms as transforms

def npz():

    #配置日志文件
    exp = 'train_dataname'
    log_folder = './datacreate_log'
    os.makedirs(log_folder, exist_ok=True)
    logging.basicConfig(filename=log_folder + '/'+exp+".txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))

    #原图像路径
    path = '/Users/XXX/Documents/codespace/TransUNet-main-xxx-data3/xxxpreparedata/image_data3_crop_deleted/*.png'

    #项目中存放训练所用的npz文件路径
    path2 = '../data/Synapse/train_npz/'

    print(glob.glob(path))
    for i,img_path in enumerate(glob.glob(path)):
        #读入图像
        image = cv2.imread(img_path) # cv2读图片,读进来默认是BGR格式的
        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB) # BGR转RGB格式 三通道
        # image = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)

        #读入标签
        # print(img_path)
        # label_path = img_path.replace('images','labels')
        label_path = img_path.replace('image_data3_crop_deleted','label_data3_crop_deleted')
        # flags: 标志位, 表示读取数据的格式,读取彩色可以设为cv2.IMREAD_COLOR(flags=1)
        # ,读取灰度图像设为cv2.IMREAD_GRAYSCALE(flags=0),读取原始图像设为cv2.IMREAD_UNCHANGED(flags=-1)。
        label = cv2.imread(label_path,flags=0) # 读取灰度图像(0-255)

        # pdb.set_trace()
        #将非目标像素设置为0
        label[label!=38]=0 #根据你自己的数据修改,这里我只有labelme标注的一类
        #将目标像素设置为1
        label[label==38]=1
        #保存npz
        #image:(256, 256, 3)  label:(256,256)
        np.savez(path2+str(i),image=image,label=label) #三通道的图,单通道的标签
        # print('------------',i)
        logging.info('train'+img_path+'----'+str(i))
    # 加载npz文件
    # data = np.load(r'G:\dataset\Unet\Swin-Unet-ori\data\Synapse\train_npz\0.npz', allow_pickle=True)
    # image, label = data['image'], data['label']

    print('ok')
npz()

仍然报

FileNotFoundError: [Errno 2] No such file or directory: '../data/Synapse/train_npz/0.npz'

因此需要手动创建文件夹data/Synapse/train_npz/
对应的列表信息也保存在datacreate_log/train_dataname.txt中;
对比方式为:
在这里插入图片描述
这里说明,2114.png对应1.npz。也就是,test_vl.txt中的索引6下的1.npz对应2114.png
在这里插入图片描述

4. 生成数据列表

在这里插入图片描述
我们可以看到,原文件中需要的数据已经准备好了,但test_vol.txt以及train.txt中的内容,还是作者给出的,需要替换掉;正常我们是直接利用test_vol_h5和train_npz中的文件名来生成,但是,我们在上一步并没有对数据进行划分,因此在这里生成list的时候,就需要进行划分了。因为最终程序也是通过list来读文件的,上一步不划分并不会产生影响。

我们,找到train.txt和test_vol.txt,手动将文件里面的内容清空(train.txt不清空也行,后面的代码会直接将这个文件覆盖掉),split_data.py这个文件直接无视。自己写一个函数读取train_npz中所有的文件名称,然后将文件名称写入train.txt文件,一个名称一行,如下图所示。同理可完成test_vol.txt文件制作。(这里test文件,直接从train.txt中选一些复制到test_vol.txt中即可,无需代码生成)
我们在preparedata文件夹下创建wtire_train_txt.py

import glob

def write_name():
    #npz文件路径
    files = glob.glob('../data/Synapse/train_npz/*.npz')
    # files = glob.glob('../data/Synapse/test_vol_h5/*.npz')

    #txt文件路径
    f = open('../lists/lists_Synapse/train.txt','w')
    # f = open('../lists/lists_Synapse/test_vol.txt','w')

    for i in files:
        name = i.split('/')[-1]
        name = name[:-4]+'\n'
        f.write(name) # 按照读取文件的名字,存到train.txt列表中,然后从里面复制出来0.1样本,现在问题是,文件名和图片怎么对应
write_name()

在这里插入图片描述
从train.txt中选10%作为test_vol.txt,这里已经筛选了。

数据集制作完毕!!!代码会先去train.txt文件中读取训练样本的名称,然后根据名称再去train_npz文件夹下读取npz文件。所以每一步都很重要,必须正确!

三、TransUnet模型训练

1.下载官方预训练权重

预训练权重官方下载地址
在这里插入图片描述
在这里插入图片描述
下载上图所示这个,下载后放到如下文件夹中,这组文件夹都需要我们自己创建:
在这里插入图片描述

2.修改数据读取文件的代码

找到datasets/dataset_synapse.py文件中的Synapse_dataset类,修改__getitem__函数。

class Synapse_dataset(Dataset):
    def __init__(self, base_dir, list_dir, split, transform=None):
        self.transform = transform  # using transform in torch!
        self.split = split
        self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines()
        self.data_dir = base_dir

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        if self.split == "train":
            slice_name = self.sample_list[idx].strip('\n')
            data_path = self.data_dir + "/" + slice_name + '.npz'
            data = np.load(data_path)
            image, label = data['image'], data['label']
        else:
            slice_name = self.sample_list[idx].strip('\n')
            data_path = self.data_dir + "/" + slice_name + '.npz'
            data = np.load(data_path)
            image, label = data['image'], data['label']
            image = torch.from_numpy(image.astype(np.float32))
            image = image.permute(2, 0, 1)
            label = torch.from_numpy(label.astype(np.float32))
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = self.sample_list[idx].strip('\n')
        return sample

找到找到datasets/dataset_synapse.py文件中的RandomGenerator类,修改__call__函数。

class RandomGenerator(object):
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        x, y,_ = image.shape
        if x != self.output_size[0] or y != self.output_size[1]:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y,1), order=3)  # why not 3?
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
        image = torch.from_numpy(image.astype(np.float32))
        image = image.permute(2,0,1)
        label = torch.from_numpy(label.astype(np.float32))
        sample = {'image': image, 'label': label.long()}
        return sample

至此,与数据相关的代码都改好啦。

3.调整训练参数

  1. train.py文件中的修改部分
    在这里插入图片描述在这里插入图片描述
    在这里插入图片描述
  2. trainer.py文件中的修改部分
    其中trainer.py中的num_worker=8可以不改,只要你的设备支持的话。
    在这里插入图片描述
    OK! 至此 模型的修改基本上就完成了,可以开始训练啦,但是训练的过程中,大多数人会出错,本文第四部分会列举我在实验过程中遇到过的错误以及解决方案,看看下面会不会有和你的相似的错误。

运行代码train.py,在你自己的设备上开始查错纠错吧!!!!

## 集群运行代码

四、开始训练遇到的bug以及修改记录

注意,每个人的基础实验环境不同,遇到的bug也各不相同。并不是每个都会遇到,也并不是出了这里记录的没有别的bug,有其他bug大家可以在评论区讨论。

1. 缺失包

  1. No module named ‘ml_collections’
pip install ml_collections
  1. No module named ‘tensorboardX’
pip install tensorboardX
  1. No module named ‘medpy’
pip install medpy

2. 没有Cuda环境

AssertionError: Torch not compiled with CUDA enabled

修改原代码中的.cuda()部分,按照如下内容的例子全部改掉,有多处

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 改
    net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).to(device) #改

第一处:
在这里插入图片描述

第二处:
在这里插入图片描述

3. FileNotFoundError: [Errno 2] No such file or directory: ‘…/model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz’

FileNotFoundError: [Errno 2] No such file or directory: ‘…/model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz’

解决方案:在vit_seg_configs.py里面把路径修改一下,就是前面两个点…/改成一个点./ ; 注意,找对方法,别乱找;
在这里插入图片描述

五、训练模型

当出现下面的文字或者界面的时候,就说明已经跑起来了。
在这里插入图片描述
我们慢慢等,看看会不会出问题,如果没问题,下面继续进行测试部分;

哎发现问题了。

IndexError: index 1 is out of bounds for dimension 0 with size 1

出现该问题的原因是,主要原因就是数据(样本)的个数是奇数,而batch_size开的偶数,因此[1]这个batch_size的位置有问题了,越界了,应该改成0。

修改的部分如下:
在这里插入图片描述
在这里插入图片描述
模型训练完成后,会按照代码中的设置保存训练模型的参数文件,和训练日志:
在这里插入图片描述
默认是保存最后一次的训练结果,和每50个epoch保存一次的训练结果。

六、预测

  1. 修改test.py文件
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

在这里插入图片描述

  1. 修改utils.py文件
def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    _,x, y = image.shape
    if x != patch_size[0] or y != patch_size[1]:
        #缩放图像符合网络输入
        image = zoom(image, (1,patch_size[0] / x, patch_size[1] / y), order=3)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input = torch.from_numpy(image).unsqueeze(0).float().to(device)
    net.eval()
    with torch.no_grad():
        out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
        out = out.cpu().detach().numpy()
        if x != patch_size[0] or y != patch_size[1]:
            #缩放图像至原始大小
            prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
        else:
            prediction = out

    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))

    if test_save_path is not None:
        a1 = copy.deepcopy(prediction)
        a2 = copy.deepcopy(prediction)
        a3 = copy.deepcopy(prediction)

        a1[a1 == 1] = 0
        a2[a2 == 1] = 255
        a3[a3 == 1] = 0

        a1 = Image.fromarray(np.uint8(a1)).convert('L')
        a2 = Image.fromarray(np.uint8(a2)).convert('L')
        a3 = Image.fromarray(np.uint8(a3)).convert('L')
        prediction = Image.merge('RGB', [a1, a2, a3])
        prediction.save(test_save_path+'/'+case+'.png')
    return metric_list

在这里插入图片描述
最终根据预测结果和索引列表,查原图和标签,做对比。 也可用来组合生成掩码图。

本文章已经生成可运行项目
评论 11
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

五阿哥爱跳舞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值