首先按照别人的经验帖配置好代码:TransUnet官方代码训练自己数据集(彩色RGB3通道图像的分割)_transunet训练自己的数据集-优快云博客
我用的CSASIA-V4数据集,用什么都一样。
开始我先把我的图像压成npz格式后导入了这两个文件夹,
直接运行train.py,报错如下:
问师兄帮我解决:
注意:压成npz文件的时候先看看自己的label标签是什么,如果是如下图和我一样二分类,我的类别是0和255,即黑(0)白(255)色,就要在压成npz的时候改成0,1。类别多的同理改成0开始的连续数字,下面是我的制作npz代码
注意最关键的 :label[label > 0] = 1 把大于0的设置为1
import glob
import cv2
import numpy as np
def npz():
# 图像路径
path = r'D:\Datasets_test\CASIA-Interval-V4\test\image\*.jpg'
# 项目中存放训练所用的npz文件路径
path2 = r'D:\ProfessionalSoftware\python\Iris_Segmentation\project_TransUNet\data\Synapse\test_vol_h5\\'
for i, img_path in enumerate(glob.glob(path)):
# 读入图像
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 读入标签
label_path = img_path.replace('image', 'gt').replace('.jpg', '.tiff')
label = cv2.imread(label_path, flags=0)
label[label > 0] = 1
# 保存npz
np.savez(path2 + str(i), image=image, label=label)
print('------------', i)
# 加载npz文件
# data = np.load(r'G:\dataset\Unet\Swin-Unet-ori\data\Synapse\train_npz\0.npz', allow_pickle=True)
# image, label = data['image'], data['label']
print('ok')
# 下面的输出都是从 distributed_utils.py 而来
if __name__ == '__main__':
npz()
压成npz成功以后,还要读取文件名,代码如下:
import os
def write_filenames(folder_path, output_file):
# 获取文件夹中的所有文件
filenames = os.listdir(folder_path)
# 打开输出文件以写入
with open(output_file, 'w') as f:
for filename in filenames:
# 获取文件名(不带扩展名)
name_without_extension = os.path.splitext(filename)[0]
f.write(name_without_extension + '\n')
print(f'文件名已保存至 {output_file}')
if __name__ == '__main__':
# 输入文件夹路径
folder_path = r'D:\ProfessionalSoftware\python\Iris_Segmentation\project_TransUNet\data\Synapse\test_vol_h5'
# 输出的文件路径
output_file = r'D:\ProfessionalSoftware\python\Iris_Segmentation\project_TransUNet\TransUNet\lists\lists_Synapse\test_vol.txt'
write_filenames(folder_path, output_file)
开始测试train.py。报错如下:
下面仔细检查train.py和test.py文件,确保里面所有参数都是一样的!如果输入的图片不是正方形就写最长边,例如我是320*280,设置 img_size = 320。
运行train.py,运行成功。
成功预测: