nnUNet V2代码——图像增强(一)

nnUNet V2使用的图像增强方法

nnUNet V2会依照概率依次对图像应用以下图像增强方法:

代码-类名对应图像增强方法
SpatialTransform弹性形变仿射变换、裁剪
GaussianNoiseTransform高斯噪声
GaussianBlurTransform高斯模糊
MultiplicativeBrightnessTransform亮度
ContrastTransform对比度
SimulateLowResolutionTransform低分辨率
GammaTransform伽马矫正
MirrorTransform镜像

此外,还有一些辅助方法,部分是3D级联专属方法:

1️⃣Convert3DTo2DTransform 和 Convert2DTo3DTransform:

对于3d配置,如果图像(data)有各向异性的话,将c * x * y * z转换为(c * x) * y * z,其中x是具有各向异性的那个轴,3D变2D,在完成仿射变换、弹性变形 、裁剪 后,会转换回3D。如果直接对各向异性的数据应用三维数据增强,可能会导致深度方向的信息丢失或过度拉伸/压缩。

2️⃣MaskImageTransform:

提取seg中值为-1的无效区域,依照无效区域,将图像(data)中对应区域的值设置为0(在此之前,图像已经进行了标准化)。

nnU-Net V2将seg部分值设置为-1有两种情况:

  1. 预处理阶段,nnU-Net V2依照图像(data)的无效区域,将segmentation 的对应区域值设置为-1.
  2. 在DataLoad产生patch过程中,nnU-Net V2将patch_segmentation 填充的pad区域赋值为-1:

MaskImageTransform 会在use_mask_for_normTrue时,将上述seg区域的值设置为0。

预处理阶段,nnU-Net 如果通过裁剪非零区域显著减小了图像尺寸,则会将 use_mask_for_norm 设置为 True

本方法的意义在预处理阶段说明过。

3️⃣RemoveLabelTansform:将seg中值为-1的区域赋值为0,与 MaskImageTransform 的区别在于本方法对 segmentation 操作

4️⃣ConvertSegmentationToRegionsTransform:将segmentation中的类别转换为bool类型。

5️⃣DownsampleSegForDSTransform:

对 segmentation 进行下采样,以适配深度监督的需求。

图像增强是nnU-Net V2 数据加载(DataLoad)的一部分,nnU-Net V2通过多进程管理数据加载和训练,在这里计算出深度监督所需的各层次segmentation,提高训练速度,减少GPU等待时间。

6️⃣MoveSegAsOneHotToDataTransform:将 segmentation 转换为 one-hot 编码后,将其作为额外的输入通道添加到图像(data)中。

7️⃣ApplyRandomBinaryOperatorTransform:对图像数据中的特定通道应用随机的二值形态学操作(如膨胀、腐蚀、开运算、闭运算等)。

8️⃣RemoveRandomConnectedComponentFromOneHotEncodingTransform:从 one-hot 编码的分割标签中随机移除连通区域。

接下来按照顺序依次介绍各个方法的代码

各个图像增强代码

安装batchgeneratorsv2,nnU-Net V2关于图像增强的代码都在这个库中,点击链接,将其clone到本地后,在命令行进入文件夹内,pip install -e . 即可(注意-e后有个点)。

1. BasicTransform

图像增强变换的父类,代码在batchgeneratorsv2 \ transforms \ base \ basic_transform.py文件中

在数据增强过程中,首先调用self.__call__函数,再调用self.get_parameters函数,最后调用self.apply函数,完成数据增强。

self.apply函数通过调用self._apply_to_image 函数、self._apply_to_segmentation 函数将self.get_parameters得到的数据增强参数应用到图像seg上。

传入图像增强的图像和seg大小是 ( C , X , Y ) (C, X, Y) (C,X,Y) or ( C , X , Y , Z ) (C, X, Y, Z) (C,X,Y,Z)

2. SpatialTransform

该类包含仿射变换、弹性形变、裁剪三个操作,代码在batchgeneratorsv2 \ transforms \ spatial \ spatial.py文件中

__init __函数

定义必要的类内变量,代码清晰,不做粘贴,变量在用到时再介绍作用

get_parameters 函数

首先获取图像维度(dim变量),依据旋转概率、放缩概率、弹性形变概率确定本次是否对图像和seg使用旋转操作、放缩操作、弹性形变操作。

dim = data_dict['image'].ndim - 1

do_rotation = np.random.uniform() < self.p_rotation # nnU-Net V2设置的是0.2
do_scale = np.random.uniform() < self.p_scaling		# nnU-Net V2设置的是0.2
do_deform = np.random.uniform() < self.p_elastic_deform		# nnU-Net V2设置的是0

接下来确定仿射变换的参数。

确定各个轴的旋转角度、放缩尺度。sample_scalar函数在代码段内说明。

if do_rotation:
    angles = [sample_scalar(self.rotation, image=data_dict['image'], dim=i) for i in range(0, 3)]
else:
    angles = [0] * dim
if do_scale:
    if np.random.uniform() <= self.p_synchronize_scaling_across_axes:	# nnU-Net V2设置的是1,100%各轴同步放缩
        scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=None)] * dim
    else:
        scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=i) for i in range(0, 3)]
else:
    scales = [1] * dim
"""
sample_scalar函数会根据给定的数据类型,返回对应数据:
如果给定的是int类型或者float类型,返回它本身
if isinstance(scalar_type, (int, float)):
	return scalar_type
如果给定的是list类型或者tuple类型,说明给定的是一个范围
判定该范围是否合规:长度为2,[a,b]中a <= b
如果a == b,返回a;否则随机返回范围内的一个实数
elif isinstance(scalar_type, (list, tuple)):
    assert len(scalar_type) == 2, 'if list is provided, its length must be 2'
    assert scalar_type[0] <= scalar_type[1], 'if list is provided, first entry must be smaller or equal than second entry, ' \
                                            'otherwise we cannot sample using np.random.uniform'
    if scalar_type[0] == scalar_type[1]:
        return scalar_type[0]
    return np.random.uniform(*scalar_type)
如果给定的是一个函数,就调用它,传入对应参数 
elif callable(scalar_type):
    return scalar_type(*args, **kwargs)
"""

依据旋转、放缩参数确定仿射变换矩阵:

# affine matrix
if do_scale or do_rotation:
    if dim == 3:
        affine = create_affine_matrix_3d(angles, scales)	
    elif dim == 2:
        affine = create_affine_matrix_2d(angles[-1], scales)
    else:
        raise RuntimeError(f'Unsupported dimension: {dim}')
else:
    affine

接下来确定弹性形变的参数。nnU-Net V2代码中并未使用弹性形变,self.p_elastic_deform=0。我猜测是弹性形变无法应用于三维图像导致的

如果本次图像增强有弹性形变,就先确定各轴变形尺度(同步 or 异步,和上面放缩一样)。

if np.random.uniform() <= self.p_synchronize_def_scale_across_axes:	# nnU-Net V2 设置的同步概率为0。
    deformation_scales = [
        sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=None, patch_size=self.patch_size)
    ] * dim
else:
    deformation_scales = [
        sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=i, patch_size=self.patch_size)
        for i in range(dim)
    ]

生成卷积核,用于平滑随机位移场:

sigmas = [i * j for i, j in zip(deformation_scales, self.patch_size)]

再生成每个轴的缩放因子:

magnitude = [
    sample_scalar(self.elastic_deform_magnitude, image=data_dict['image'], patch_size=self.patch_size,
                  dim=i, deformation_scale=deformation_scales[i])
    for i in range(dim)
]

生成随机位移场,offsets各轴负责图像各轴的位移:

offsets = torch.normal(mean=0, std=1, size=(dim, *self.patch_size))

对于每一个轴,用高斯模糊平滑随机位移场,之后将形变幅度限制在合理范围 :

for d in range(dim):
    # 平滑
    tmp = np.fft.fftn(offsets[d].numpy())
    tmp = fourier_gaussian(tmp, sigmas[d])
    offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real)
    # 限制形变幅度
    mx = torch.max(torch.abs(offsets[d]))
	offsets[d] /= (mx / np.clip(magnitude[d], a_min=1e-8, a_max=np.inf))

最后调整随机位移场的维度,用于后续grid_sample函数计算。

spatial_dims = tuple(list(range(1, dim + 1)))
offsets = torch.permute(offsets, (*spatial_dims, 0))	# offsets.shape = H*W*2

接下来确定裁剪的参数。nnU-Net V2训练时有两处裁剪,一处是数据加载(dataload),另一处是数据增强。数据加载处按照预处理阶段确定的裁剪中心和patch size进行裁剪,是一种随机裁剪;数据增强处不是随即裁剪,而是以图像正中心和patch size进行裁剪,参数self.random_crop=False。一个batch中1/3不是随机裁剪这个特性是在数据加载处实现的。

获取图像大小,计算裁剪中心:

shape = data_dict['image'].shape[1:]	# 获取图像大小(H * W 或者 x * y * z)
if not self.random_crop:	# 不随即裁剪的话,裁剪中心就是图像中心
    center_location_in_pixels = [i / 2 for i in shape]
else:						# 随机裁剪
    center_location_in_pixels = []
    for d in range(0, 3):	
    """
    这里的维度固定是3,与上面shape的维度在处理2d图像时不一样
    考虑到nnU-Net V2代码并未使用随即裁剪,self.random_crop=False
    我猜测nnU-Net V2作者还没有优化到这里,跳过
    """
        mn = self.patch_center_dist_from_border[d]
        mx = shape[d] - self.patch_center_dist_from_border[d]
        if mx < mn:
            center_location_in_pixels.append(shape[d] / 2)
        else:
            center_location_in_pixels.append(np.random.uniform(mn, mx))

函数最后返回确定的参数:

return {
    'affine': affine,
    'elastic_offsets': offsets,
    'center_location_in_pixels': center_location_in_pixels
}

_apply_to_image 函数

首先是没有仿射变换弹性形变,直接裁剪的情况。

# if params['affine'] is None and params['elastic_offsets'] is None:
# 先选取填充模式
if self.padding_mode_image == 'reflection':
    pad_mode = 'reflect'
    pad_kwargs = {}
elif self.padding_mode_image == 'zeros':
    pad_mode = 'constant'
    pad_kwargs = {'value': 0}
elif self.padding_mode_image == 'border':
    pad_mode = 'replicate'
    pad_kwargs = {}
else:
    raise RuntimeError('Unknown pad mode')
# 之后根据裁剪中心,裁剪、填充即可
img = crop_tensor(img, [math.floor(i) for i in params['center_location_in_pixels']], self.patch_size, pad_mode=pad_mode,
                    pad_kwargs=pad_kwargs)
return img

如果仿射变换弹性形变,先进行弹性形变,再进行仿射变换,最后裁剪:

# else
# 创建一个以中心为原点的网格,每一组元素都是一个坐标
grid = _create_centered_identity_grid2(self.patch_size)
"""
def _create_centered_identity_grid2(size: Union[Tuple[int, ...], List[int]]) -> torch.Tensor:
	## 按照size各维度范围,生成多个一维数组,例如size=[5, 4]
	## space[0] = 从-2到2之间均匀采样5个点,其余维度同理
    space = [torch.linspace((1 - s) / 2, (s - 1) / 2, s) for s in size]
    grid = torch.meshgrid(space, indexing="ij")
    ## 沿着坐标维度拼接,组成坐标
    grid = torch.stack(grid, -1)
    return grid
"""

# we deform first, then rotate
# 如果存在随机位移场,将其加到网格上
if params['elastic_offsets'] is not None:
    grid += params['elastic_offsets']
# 如果存在仿射变换矩阵,将其应用到网格上
if params['affine'] is not None:
    grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())

接下来如果使用中心弹性形变,需要计算均值,表示网格的几何中心 :

if self.center_deformation and params['elastic_offsets'] is not None:
    mn = grid.mean(dim=list(range(img.ndim - 1)))
else:
    mn = 0

接下来根据 get_parameters 函数中设置的裁剪中心点和图像大小,计算出裁剪中心点的像素坐标:

## 裁剪中心点的像素坐标
new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], img.shape[1:])])
## 在应用弹性形变和仿射变换后,网格的几何中心可能会发生变化
## 为了确保裁剪中心点 new_center 在形变后的图像中仍然位于正确的位置
## 我们需要将网格的几何中心调整到 new_center。
grid += (new_center - mn)

最后使用grid_sample对图像采样:

## 有图像img,有各个像素点采样前后坐标点(grid索引是原有坐标点,索引对应的值是采样后的坐标点)
## 采样后不在网格上的像素值通过双线性插值计算出在网格上的像素值
return grid_sample(img[None], _convert_my_grid_to_grid_sample_grid(grid, img.shape[1:])[None],
                    mode='bilinear', padding_mode=self.padding_mode_image, align_corners=False)[0]
"""
_convert_my_grid_to_grid_sample_grid函数会将grid归一化
(-1, -1) (-1, 1) (1, -1) (1, 1)对应img的四个角
"""

_apply_to_segmentation 函数

前半部分代码和 _apply_to_image 函数一样,一直到grid += (new_center - mn)

## 归一化grid
grid = _convert_my_grid_to_grid_sample_grid(grid, segmentation.shape[1:])

## 如果segment选用最近邻插值
## 代码和_apply_to_image 函数函数一致
## 差别只在将segment转换为float
if self.mode_seg == 'nearest':
    result_seg = grid_sample(
                    segmentation[None].float(),
                    grid[None],
                    mode=self.mode_seg,
                    padding_mode=self.border_mode_seg,
                    align_corners=False
                )[0].to(segmentation.dtype)
else:
    ## 此处略

如果不是最近邻插值:

首先初始化一个用于返回的零张量,大小和传入的segmentation一样(传进来前已经是patch size的形状了):

## if self.mode_seg == 'nearest':
## 		代码略
## else:
result_seg = torch.zeros((segmentation.shape[0], *self.patch_size), dtype=segmentation.dtype)

接下来如果self.bg_style_seg_samplingTrue(nnU-Net V2设置为False):

## if self.bg_style_seg_sampling:
## 遍历所有通道
for c in range(segmentation.shape[0]):	
    ## 获取本通道内包含的所有标签值
    labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
    # if we only have 2 labels then we can save compute time
    ## 二分类的情况(背景和一个前景)
    if len(labels) == 2:
        ## 直接采样,并根据阈值(0.5)判定像素属于哪个类别
        out = grid_sample(
                ((segmentation[c] == labels[1]).float())[None, None],
                grid[None],
                mode=self.mode_seg,
                padding_mode=self.border_mode_seg,
                align_corners=False
            )[0][0] >= 0.5
        result_seg[c][out] = labels[1]
        result_seg[c][~out] = labels[0]
    else:	## 多分类情况
        ## 遍历所有标签
        for i, u in enumerate(labels):
            ## 对当前标签采样,可以看成是一个二分类,我和其他
            ## 阈值依旧是0.5
            result_seg[c][
                grid_sample(
                    ((segmentation[c] == u).float())[None, None],
                    grid[None],
                    mode=self.mode_seg,
                    padding_mode=self.border_mode_seg,
                    align_corners=False
                )[0][0] >= 0.5] = u

如果self.bg_style_seg_samplingFalse

## if self.bg_style_seg_sampling:
## 		代码略
## else:
## 遍历所有通道
for c in range(segmentation.shape[0]):
    ## 获取本通道内包含的所有标签值
    labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
    ## 创建tmp存储每个标签的采样结果
    tmp = torch.zeros((len(labels), *self.patch_size), dtype=torch.float16)
    ## 缩放因子
    scale_factor = 1000
    ## 创建和patch size相同大小的零张量,记录哪些像素已经被分配了类别
    done_mask = torch.zeros(*self.patch_size, dtype=torch.bool)
    ## 遍历所有标签
    for i, u in enumerate(labels):
        ## 对当前标签采样,与self.bg_style_seg_sampling为True的多分类采样不同的是
        ## 此处要对输入grid_sample的、当前通道的二值掩码((segmentation[c] == u).float())进行放大
        ## * scale_factor;判定类别的阈值也不一样,变为(0.7 * scale_factor)
        tmp[i] = grid_sample(((segmentation[c] == u).float() * scale_factor)[None, None], grid[None],
                                mode=self.mode_seg, padding_mode=self.border_mode_seg, align_corners=False)[0][0]
        mask = tmp[i] > (0.7 * scale_factor)
        ## 存入返回结果中
        result_seg[c][mask] = u
        ## 记录当前标签对应的像素已经分配了类别
        done_mask = done_mask | mask
    ## 对于未被分配类别的像素
    if not torch.all(done_mask):
        ## 由于阈值的存在,这些像素有采样的数值,但没有分配类别
        ## 对于这些像素,选择采样结果最大的标签
        result_seg[c][~done_mask] = labels[tmp[:, ~done_mask].argmax(0)]
    del tmp	## 处理临时变量

无论是不是最近邻插值,都已经处理完毕,最后返回数据增强后的segmentation:

del grid
return result_seg.contiguous()

其余函数

代码清晰,不做粘贴。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值