1. 划分数据集
import os
from sklearn.model_selection import train_test_split
import shutil
def remove_file(path):
    """Delete the file at *path* if it exists; do nothing otherwise."""
    if not os.path.exists(path):
        return
    print(f"文件 {path} 已存在,正在删除...")
    os.remove(path)
# Remove split listing files left over from a previous run.
train_path = 'train.txt'
valid_path = 'valid.txt'
test_path = 'test.txt'
for _stale_path in (train_path, valid_path, test_path):
    remove_file(_stale_path)
# Root directory of the BraTS2021 dataset (adjust to your local path).
data_root = 'dataset'

# Every sub-directory of the root is treated as one subject/sample.
subjects = [
    entry
    for entry in os.listdir(data_root)
    if os.path.isdir(os.path.join(data_root, entry))
]

# Two seeded 80/20 splits: first carve off the test set, then carve the
# validation set out of the remaining training portion.
train_subjects, test_subjects = train_test_split(subjects, test_size=0.2, random_state=42)
train_subjects, valid_subjects = train_test_split(train_subjects, test_size=0.2, random_state=42)
# 保存划分结果到文件
def save_to_file(file_path, data):
    """Write every item of *data* to *file_path*, one item per line."""
    lines = [f"{item}\n" for item in data]
    with open(file_path, 'w') as handle:
        handle.writelines(lines)
# Persist all three splits as one-subject-per-line text files.
for _fname, _split in (
    ('train.txt', train_subjects),
    ('valid.txt', valid_subjects),
    ('test.txt', test_subjects),
):
    save_to_file(os.path.join(data_root, _fname), _split)

print(f"训练集大小: {len(train_subjects)}")
print(f"验证集大小: {len(valid_subjects)}")
print(f"测试集大小: {len(test_subjects)}")
# Create the train / valid / test target directories (idempotent).
train_dir = os.path.join(data_root, 'train')
val_dir = os.path.join(data_root, 'valid')
test_dir = os.path.join(data_root, 'test')
for _split_dir in (train_dir, val_dir, test_dir):
    os.makedirs(_split_dir, exist_ok=True)
# 将数据复制到对应的文件夹
def copy_files(subjects, source_dir, target_dir):
    """Copy each subject's folder from *source_dir* into *target_dir*.

    Args:
        subjects: iterable of subject directory names.
        source_dir: directory that contains the subject folders.
        target_dir: destination directory (must already exist).

    Missing source folders are reported with a warning instead of raising,
    so a partially-populated dataset does not abort the whole copy.
    """
    for subject in subjects:
        src = os.path.join(source_dir, subject)
        dst = os.path.join(target_dir, subject)
        if os.path.exists(src):
            # dirs_exist_ok makes re-runs idempotent: without it a second
            # run crashes with FileExistsError, because the cleanup step
            # only removes the .txt listings, never these directories.
            shutil.copytree(src, dst, dirs_exist_ok=True)
            print(f"复制: {src} -> {dst}")
        else:
            print(f"警告: 源路径不存在 {src}")
# Materialize each split by copying its subject folders into place.
for _subjects, _target in (
    (train_subjects, train_dir),   # training set
    (valid_subjects, val_dir),     # validation set
    (test_subjects, test_dir),     # test set
):
    copy_files(_subjects, data_root, _target)
print("数据集划分和复制完成!")
2. 将.nii转为.pkl
Convert the .nii files to .pkl files and perform data normalization.
import pickle
import os
import numpy as np
import nibabel as nib
# MRI modalities expected for every subject, in channel-stacking order.
modalities = ('flair', 't1ce', 't1', 't2')

# Root directory of the BraTS2021 dataset.
datasets_path = "../../../../BraTs/BraTs2021"

# Per-split configuration: where the split lives, which listing file
# enumerates its subjects, and whether segmentation labels exist.
train_set = dict(root=datasets_path, flist='train.txt', has_label=True)
valid_set = dict(root=datasets_path, flist='valid.txt', has_label=False)
test_set = dict(root=datasets_path, flist='test.txt', has_label=False)
def nib_load(file_name):
    """Load a NIfTI volume and return its voxel data as a numpy array.

    Args:
        file_name: path to a .nii / .nii.gz file.

    Raises:
        FileNotFoundError: if *file_name* does not exist. (The original
            code only printed a warning and then crashed inside nib.load
            with an opaque error; failing fast here is clearer.)
    """
    if not os.path.exists(file_name):
        raise FileNotFoundError(
            f'Invalid file name, can not find the file: {file_name}')
    proxy = nib.load(file_name)
    data = proxy.get_fdata()
    # Drop nibabel's internal cache so large volumes are not held twice.
    proxy.uncache()
    return data
def process_i16(path, has_label=True):
    """Save the original 3D MRI images with dtype=int16 (no normalization).

    Args:
        path: subject file prefix, e.g. '<root>/<subject>/<name>_'.
        has_label: whether a 'seg.nii.gz' ground-truth file exists.

    Writes ``<path>data_i16.pkl`` containing ``(images, label)`` when
    *has_label* is True, otherwise just ``images``.
    """
    label = None
    if has_label:
        # Only labelled splits ship a segmentation file; the original code
        # loaded it unconditionally and crashed for unlabeled splits.
        label = np.array(nib_load(path + 'seg.nii.gz'), dtype='uint8', order='C')
    # Stack the four modalities on the last axis: (240, 240, 155, 4).
    images = np.stack([
        np.array(nib_load(path + modal + '.nii.gz'), dtype='int16', order='C')
        for modal in modalities], -1)

    output = path + 'data_i16.pkl'
    with open(output, 'wb') as f:
        print(output)
        if has_label:
            print(images.shape, type(images), label.shape, type(label))
            pickle.dump((images, label), f)
        else:
            pickle.dump(images, f)
def process_f32b0(path, has_label=True):
    """Save the data with dtype=float32, z-score normalised per modality.

    Normalisation statistics are computed over foreground voxels only
    (voxels where the four modalities sum to > 0); background stays zero.

    Args:
        path: subject file prefix, e.g. '<root>/<subject>/<name>_'.
        has_label: whether a 'seg.nii.gz' ground-truth file exists.

    Writes ``<path>data_f32b0.pkl`` containing ``(images, label)`` when
    *has_label* is True, otherwise just ``images``.
    """
    if has_label:
        label = np.array(nib_load(path + 'seg.nii.gz'), dtype='uint8', order='C')
    # Stack the four modalities on the last axis: (240, 240, 155, 4).
    images = np.stack(
        [np.array(nib_load(path + modal + '.nii.gz'), dtype='float32', order='C')
         for modal in modalities], -1)

    output = path + 'data_f32b0.pkl'
    print("output=", output)

    # Foreground mask: voxels where at least one modality is non-zero.
    mask = images.sum(-1) > 0
    for k in range(4):
        x = images[..., k]
        y = x[mask]
        # z-score over foreground voxels only; background is left at zero.
        x[mask] -= y.mean()
        x[mask] /= y.std()
        images[..., k] = x

    with open(output, 'wb') as f:
        print(output)
        if has_label:
            pickle.dump((images, label), f)
        else:
            pickle.dump(images, f)
def doit(dset):
    """Convert every subject listed in a split's file list to .pkl format.

    Args:
        dset: dict with keys 'root' (dataset directory), 'flist' (listing
            file with one subject path per line) and 'has_label'.
    """
    root, has_label = dset['root'], dset['has_label']
    file_list = os.path.join(root, dset['flist'])
    # Close the listing file deterministically instead of leaking the
    # handle (the original relied on GC to close it).
    with open(file_list) as f:
        subjects = f.read().splitlines()
    names = [sub.split('/')[-1] for sub in subjects]
    # Each subject's files share the prefix '<root>/<subject>/<name>_'.
    paths = [os.path.join(root, sub, name + '_') for sub, name in zip(subjects, names)]
    print("paths=", len(paths))
    for path in paths:
        process_f32b0(path, has_label)
if __name__ == '__main__':
    # Convert all three splits; only the training split carries labels.
    for _dataset in (train_set, valid_set, test_set):
        doit(_dataset)