图像 shape: (128, 448, 448)
标签 shape: (128, 448, 448)
--- 图像统计 ---
最小值: 0
最大值: 1162
均值: 52.04410152046048
标准差: 52.50468605606327
是否可能归一化: False
是否包含负值: False
--- 标签统计 ---
唯一值: [0. 1.]
是否二值标签: True
前景(1)占比: 0.003764444467972736
中间层切片前景像素数: 880
这是我的数据信息。我的数据处理部分如代码所示,是否正确:
import os
import torch
import numpy as np
from glob import glob
from torch.utils.data import Dataset
import h5py
import itertools
from torch.utils.data.sampler import Sampler
class BraTS2019(Dataset):
    """Two-channel volumetric dataset backed by one HDF5 file per case.

    Case names are read, one per line, from
    ``<base_dir>/datanewnew/train.txt`` (``split='train'``) or
    ``<base_dir>/datanewnew/val.txt`` (``split='test'``). Each case is
    loaded from ``<base_dir>/datanewnew/<name>.h5``, which must contain the
    datasets ``image`` and ``mip`` and may contain ``label``.

    Args:
        base_dir: Directory that contains the ``datanewnew`` folder.
        split: ``'train'`` or ``'test'``.
        num: If not ``None``, keep only the first ``num`` cases.
        transform: Optional callable applied to every sample dict.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """

    # Statistics used for the z-score normalisation in ``__getitem__``.
    IMAGE_MEAN = 52.04
    IMAGE_STD = 52.50

    def __init__(self, base_dir=None, split='train', num=None, transform=None):
        self._base_dir = base_dir
        self.transform = transform
        self.sample_list = []
        train_path = self._base_dir + '/datanewnew/train.txt'
        test_path = self._base_dir + '/datanewnew/val.txt'

        if split == 'train':
            list_path = train_path
        elif split == 'test':
            list_path = test_path
        else:
            # Previously an unknown split left ``image_list`` undefined and
            # only failed later with an AttributeError.
            raise ValueError(
                "split must be 'train' or 'test', got {!r}".format(split))

        with open(list_path, 'r') as f:
            # Drop surrounding whitespace and skip blank lines so that an
            # empty trailing line does not become a case named "".
            self.image_list = [line.strip() for line in f if line.strip()]

        if num is not None:
            self.image_list = self.image_list[:num]
        print("Total samples:", len(self.image_list))

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        """Load one case and return ``{'image': ..., 'label': ...}``.

        ``image`` stacks the normalised ``image`` and ``mip`` volumes on a
        new leading channel axis; ``label`` gets a leading singleton axis
        and is cast to ``uint8``. A case without a ``label`` dataset gets an
        all-zero label of the same spatial shape.

        Raises:
            IndexError: If ``index`` is not smaller than ``len(self)``.
        """
        if index >= len(self.image_list):
            raise IndexError("Index out of range: {}".format(index))
        image_name = self.image_list[index].strip()
        h5_path = self._base_dir + "/datanewnew/{}.h5".format(image_name)

        # Read everything inside the context manager so the file handle is
        # released even when a dataset is missing or a transform fails.
        with h5py.File(h5_path, 'r') as h5f:
            img1 = h5f['image'][:]
            img2 = h5f['mip'][:]
            label = h5f['label'][:] if 'label' in h5f else None

        # NOTE(review): ``mip`` is normalised with the same statistics as
        # ``image`` -- confirm that these statistics also describe ``mip``.
        img1 = (img1 - self.IMAGE_MEAN) / self.IMAGE_STD
        img2 = (img2 - self.IMAGE_MEAN) / self.IMAGE_STD
        image = np.stack([img1, img2], axis=0)  # adds the channel axis

        if label is None:
            label = np.zeros_like(img1, dtype=np.uint8)
        label = np.expand_dims(label, axis=0)  # leading singleton axis

        sample = {'image': image, 'label': label.astype(np.uint8)}
        if self.transform:
            sample = self.transform(sample)
        return sample
class CenterCrop(object):
    """Crop the central patch out of a sample's image and label.

    The last three axes of ``image`` and ``label`` are treated as the
    spatial axes; any leading axes (e.g. channels) are left untouched, so
    the two arrays may have different channel counts. Spatial axis ``i`` is
    cropped to ``output_size[i]``.

    If any spatial axis is smaller than requested, all spatial axes are
    first zero-padded symmetrically by ``max((want - size) // 2 + 3, 0)``.

    Args:
        output_size: Sequence of three ints, one per spatial axis.

    Raises:
        ValueError: If image and label disagree on their spatial shape.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    @staticmethod
    def _pad_spatial(volume, margins):
        """Zero-pad the last three axes of ``volume`` by ``margins``."""
        pad_width = [(0, 0)] * (volume.ndim - 3) + [(m, m) for m in margins]
        return np.pad(volume, pad_width, mode='constant', constant_values=0)

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        target = (self.output_size[0], self.output_size[1], self.output_size[2])

        # Only the spatial extents have to agree; comparing full shapes would
        # reject a multi-channel image paired with a single-channel label.
        if image.shape[-3:] != label.shape[-3:]:
            raise ValueError(
                "Image and label spatial shape mismatch: {} vs {}".format(
                    image.shape, label.shape))

        spatial = label.shape[-3:]
        if any(size < want for size, want in zip(spatial, target)):
            margins = [max((want - size) // 2 + 3, 0)
                       for size, want in zip(spatial, target)]
            # NOTE(review): the pad value is 0; if the image was z-score
            # normalised upstream, 0 is the mean intensity rather than the
            # background value -- confirm this is intended.
            image = self._pad_spatial(image, margins)
            label = self._pad_spatial(label, margins)

        # Every axis uses the same index for its size, its padding and its
        # crop length, so non-cubic output sizes are handled consistently.
        starts = [int(round((size - want) / 2.))
                  for size, want in zip(image.shape[-3:], target)]
        window = (Ellipsis,) + tuple(
            slice(start, start + want) for start, want in zip(starts, target))
        return {'image': image[window], 'label': label[window]}
class RandomCrop(object):
    """Crop a randomly positioned patch out of a sample.

    The last three axes of ``image``, ``label`` (and ``sdf`` when
    ``with_sdf`` is set) are treated as the spatial axes; leading axes such
    as channels are left untouched, so the arrays may have different channel
    counts. Spatial axis ``i`` is cropped to ``output_size[i]`` and the same
    window is applied to every array.

    If any spatial axis is smaller than requested, all spatial axes are
    first zero-padded symmetrically by ``max((want - size) // 2 + 3, 0)``.

    Args:
        output_size: Sequence of three ints, one per spatial axis.
        with_sdf: When true, ``sample['sdf']`` is cropped too and returned
            under the key ``'sdf'``.

    Raises:
        ValueError: If image and label disagree on their spatial shape.
    """

    def __init__(self, output_size, with_sdf=False):
        self.output_size = output_size
        self.with_sdf = with_sdf

    @staticmethod
    def _pad_spatial(volume, margins):
        """Zero-pad the last three axes of ``volume`` by ``margins``."""
        pad_width = [(0, 0)] * (volume.ndim - 3) + [(m, m) for m in margins]
        return np.pad(volume, pad_width, mode='constant', constant_values=0)

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        sdf = sample['sdf'] if self.with_sdf else None
        target = (self.output_size[0], self.output_size[1], self.output_size[2])

        # Only the spatial extents have to agree; comparing full shapes would
        # reject a multi-channel image paired with a single-channel label.
        if image.shape[-3:] != label.shape[-3:]:
            raise ValueError(
                "Image and label spatial shape mismatch: {} vs {}".format(
                    image.shape, label.shape))

        spatial = label.shape[-3:]
        if any(size < want for size, want in zip(spatial, target)):
            margins = [max((want - size) // 2 + 3, 0)
                       for size, want in zip(spatial, target)]
            # NOTE(review): the pad value is 0; if the image was z-score
            # normalised upstream, 0 is the mean intensity rather than the
            # background value -- confirm this is intended.
            image = self._pad_spatial(image, margins)
            label = self._pad_spatial(label, margins)
            if self.with_sdf:
                sdf = self._pad_spatial(sdf, margins)

        # Every axis uses the same index for its size, its padding and its
        # crop length, so non-cubic output sizes are handled consistently.
        starts = [np.random.randint(0, size - want + 1)
                  for size, want in zip(image.shape[-3:], target)]
        window = (Ellipsis,) + tuple(
            slice(start, start + want) for start, want in zip(starts, target))

        cropped = {'image': image[window], 'label': label[window]}
        if self.with_sdf:
            cropped['sdf'] = sdf[window]
        return cropped
class RandomRotFlip(object):
    """Randomly rotate and then flip a sample's image and label together.

    Both arrays are rotated by the same random multiple of 90 degrees in the
    plane of axes 2 and 3, then flipped along the same randomly chosen axis
    out of 1, 2 and 3. Axis 0 is never touched. Only the ``image`` and
    ``label`` keys are returned.
    """

    def __call__(self, sample):
        image = sample['image']
        label = sample['label']

        # One draw shared by image and label keeps them aligned.
        quarter_turns = np.random.randint(0, 4)
        image = np.rot90(image, quarter_turns, axes=(2, 3))
        label = np.rot90(label, quarter_turns, axes=(2, 3))

        # Offset by one so that axis 0 is never the flip axis.
        flip_axis = np.random.randint(0, 3) + 1
        flipped_image = np.flip(image, axis=flip_axis).copy()
        flipped_label = np.flip(label, axis=flip_axis).copy()
        return {'image': flipped_image, 'label': flipped_label}
class RandomNoise(object):
    """Add clipped Gaussian noise to a sample's image.

    The noise is ``sigma`` times standard-normal samples, clipped to
    ``[-2 * sigma, 2 * sigma]`` and then shifted by ``mu``. The label is
    passed through unchanged; only ``image`` and ``label`` are returned.

    Args:
        mu: Offset added to the clipped noise.
        sigma: Scale of the noise before clipping.
    """

    def __init__(self, mu=0, sigma=0.1):
        self.mu = mu
        self.sigma = sigma

    def __call__(self, sample):
        image = sample['image']
        label = sample['label']
        limit = 2 * self.sigma
        perturbation = self.sigma * np.random.randn(*image.shape)
        perturbation = np.clip(perturbation, -limit, limit) + self.mu
        return {'image': image + perturbation, 'label': label}
class CreateOnehotLabel(object):
    """Attach a one-hot encoding of the label to a sample.

    The label may be a 3D array or a 4D array with a leading singleton
    axis; in the 4D case that axis is dropped before encoding. The result is
    a ``float32`` array of shape ``(num_classes,) + spatial_shape`` stored
    under ``'onehot_label'``. ``image`` and ``label`` are returned as given.

    Args:
        num_classes: Number of classes; class ``i`` fills channel ``i``.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # A label shaped (1, D, H, W) would otherwise yield a buffer of shape
        # (num_classes, 1, D, H) and a broadcast error on assignment.
        class_map = label
        if class_map.ndim == 4 and class_map.shape[0] == 1:
            class_map = class_map[0]

        onehot_label = np.zeros(
            (self.num_classes, class_map.shape[0], class_map.shape[1],
             class_map.shape[2]),
            dtype=np.float32)
        for class_id in range(self.num_classes):
            onehot_label[class_id, :, :, :] = (
                class_map == class_id).astype(np.float32)
        return {'image': image, 'label': label, 'onehot_label': onehot_label}
class ToTensor(object):
    """Convert the ndarrays in a sample to torch tensors.

    ``image`` becomes a ``float32`` tensor; ``label`` and, when present,
    ``onehot_label`` become ``int64`` tensors. Other keys are not returned.
    """

    @staticmethod
    def _as_long_tensor(array):
        """Return ``array`` as an ``int64`` tensor.

        ``torch.from_numpy`` rejects arrays with negative strides (for
        example an un-copied ``np.flip`` view), so the array is made
        contiguous first; this is a no-op for already contiguous input.
        """
        return torch.from_numpy(np.ascontiguousarray(array)).long()

    def __call__(self, sample):
        # ``astype`` always copies, so the tensor never aliases the input.
        image = np.ascontiguousarray(sample['image'].astype(np.float32))
        converted = {
            'image': torch.from_numpy(image),
            'label': self._as_long_tensor(sample['label']),
        }
        if 'onehot_label' in sample:
            converted['onehot_label'] = self._as_long_tensor(
                sample['onehot_label'])
        return converted
class TwoStreamBatchSampler(Sampler):
    """Yield batches that mix indices from two index sets.

    An 'epoch' is one pass through the primary indices. During the epoch,
    the secondary indices are iterated through as many times as needed.
    Each batch is a tuple of ``batch_size - secondary_batch_size`` primary
    indices followed by ``secondary_batch_size`` secondary indices.

    When ``secondary_batch_size`` is 0 (no secondary stream), batches
    consist of primary indices only.

    Args:
        primary_indices: Indices iterated exactly once per epoch.
        secondary_indices: Indices reshuffled and repeated as needed.
        batch_size: Total number of indices per batch.
        secondary_batch_size: Number of secondary indices per batch.

    Raises:
        AssertionError: If there are fewer primary indices than the primary
            share of a batch, or that share is not positive.
        ValueError: If secondary indices are requested per batch but
            ``secondary_indices`` is empty.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size
        assert len(self.primary_indices) >= self.primary_batch_size > 0
        # An empty secondary set can never fill a batch; iterating would spin
        # forever reshuffling nothing, so fail early instead.
        if self.secondary_batch_size > 0 and len(self.secondary_indices) == 0:
            raise ValueError(
                "secondary_indices is empty but secondary_batch_size is "
                "{}".format(self.secondary_batch_size))

    def __iter__(self):
        primary_batches = grouper(
            iterate_once(self.primary_indices), self.primary_batch_size)
        if self.secondary_batch_size <= 0:
            # Grouping into chunks of size 0 yields nothing, which would
            # silently produce an empty epoch; emit primary-only batches.
            return iter(primary_batches)
        secondary_batches = grouper(
            iterate_eternally(self.secondary_indices),
            self.secondary_batch_size)
        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch)
            in zip(primary_batches, secondary_batches)
        )

    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size
def iterate_once(iterable):
    """Return the result of ``np.random.permutation`` applied to ``iterable``."""
    shuffled = np.random.permutation(iterable)
    return shuffled
def iterate_eternally(indices):
    """Return an endless iterator over freshly shuffled passes of ``indices``.

    Each pass is an independent ``np.random.permutation(indices)``; the
    passes are chained back to back.

    Raises:
        ValueError: If ``indices`` is empty. Without this check the returned
            iterator would loop forever without producing an item.
    """
    if isinstance(indices, (int, np.integer)):
        # ``np.random.permutation(n)`` permutes ``arange(n)``.
        count = int(indices)
    else:
        count = len(indices)
    if count == 0:
        raise ValueError("cannot iterate eternally over empty indices")

    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)

    return itertools.chain.from_iterable(infinite_shuffles())
def grouper(iterable, n):
    """Collect data into fixed-length tuples, dropping an incomplete tail.

    Example: ``grouper('ABCDEFG', 3)`` yields ``ABC`` and ``DEF``.
    """
    # Zipping one shared iterator with itself n times makes every output
    # tuple consume n consecutive items.
    shared = iter(iterable)
    return zip(*itertools.repeat(shared, n))
最新发布