nnUNet V2代码——图像增强(二)

本文阅读的nnU-Net V2图像增强有高斯噪声高斯模糊

各个类内的各个函数的调用关系见前文nnUNet V2代码——图像增强(一)BasicTransform

安装batchgeneratorsv2,nnU-Net V2关于图像增强的代码都在这个库中,点击链接,将其clone到本地后,在命令行进入文件夹内,pip install -e . 即可(注意-e后有个点)。

一 ImageOnlyTransform

本类只对image施加数据增强,seg不操作

# 继承自BasicTransform
class ImageOnlyTransform(BasicTransform):
    def apply(self, data_dict: dict, **params) -> dict:
        # 只对image施加
        if data_dict.get('image') is not None:
            data_dict['image'] = self._apply_to_image(data_dict['image'], **params)
        return data_dict

二 GaussianNoiseTransform

GaussianNoiseTransform负责施加高斯噪声,继承自ImageOnlyTransform类,只对image施加,seg不施加。代码位于batchgeneratorsv2 \ transforms \ intensity\ gaussian_noise.py

1. __init__函数

定义必要的类内变量,代码清晰,不做粘贴,变量在用到时再介绍作用

2. get_parameters函数

本函数主要确定在哪些通道上使用高斯噪声以及确定高斯噪声的 σ σ σ 大小:

# 先获取图像大小
shape = data_dict['image'].shape
dct = {}
# 依概率确定各个通道是否施加高斯噪声
# nnU-Net V2对每一个通道都施加高斯噪声
dct['apply_to_channel'] = torch.rand(shape[0]) < self.p_per_channel	# self.p_per_channel = 1 
# 施加高斯噪声的所有通道是否选用一样的sigmas(标准差),也就是是否同步
# self.synchronize_channels = True
# nnU-Net V2对所有通道施加相同标准差的高斯噪声
dct['sigmas'] = \
    [sample_scalar(self.noise_variance, data_dict['image'])
        for i in range(sum(dct['apply_to_channel']))] if not self.synchronize_channels \
        else sample_scalar(self.noise_variance, data_dict['image'])
return dct

sample_scalar函数见nnUNet V2代码——图像增强(一)其余函数

3. _apply_to_image函数

# 没有施加的通道就直接返回img
if sum(params['apply_to_channel']) == 0:
    return img
# 调用_sample_gaussian_noise获取和img相同形状的高斯噪声
gaussian_noise = self._sample_gaussian_noise(img.shape, **params)
# 注入原图,返回
img[params['apply_to_channel']] += gaussian_noise
return img

4. _sample_gaussian_noise函数

# 如果params['sigmas']字段不是list类型,说明各通道同步施加相同的高斯噪声
if not isinstance(params['sigmas'], list):
    # 确定施加的通道数目
    num_channels = sum(params['apply_to_channel'])
    # 创建一张高斯图,然后沿通道维度复制
    gaussian = torch.normal(0, params['sigmas'], size=(1, *img_shape[1:]))
    gaussian.expand((num_channels, *[-1]*(len(img_shape) - 1)))
# 如果params['sigmas']字段是list类型,说明各通道施加各自的高斯噪声
else:
   # for循环遍历各个通道,创建独立的噪声图
    gaussian = [
        torch.normal(0, i, size=(1, *img_shape[1:])) for i in params['sigmas']
    ]
    # 将所有的高斯图沿通道维度拼接起来
    gaussian = torch.cat(gaussian, dim=0)
return gaussian

三 GaussianBlurTransform

GaussianBlurTransform类负责施加高斯模糊,并且使用可分离的高斯滤波器来提升速度。继承自ImageOnlyTransform类,只对image施加,seg不施加。

代码位于 batchgeneratorsv2 \ transforms \ noise \ gaussian_blur.py文件

1. __init__函数

定义必要的类内变量,代码清晰,不做粘贴,变量在用到时再介绍作用

2. get_parameters函数

# 获取image形状
shape = data_dict['image'].shape
# 获取image维度(二维 / 三维)
dims = len(shape) - 1
dct = {}
# 依概率确定各个通道是否施加高斯模糊
# self.p_per_channel = 0.5
dct['apply_to_channel'] = torch.rand(shape[0]) < self.p_per_channel
# 如果所有通道采用一样的模糊核
# self.synchronize_channels = False
# nnU-Net V2对模糊的通道采样不同的模糊核
if self.synchronize_axes:
    dct['sigmas'] = \
        [[sample_scalar(self.blur_sigma, shape, dim=None)] * dims
            for _ in range(sum(dct['apply_to_channel']))] \
            if not self.synchronize_channels else \
            [sample_scalar(self.blur_sigma, shape, dim=None)] * dims
# 如果各个通道采用各自的模糊核
else:
    dct['sigmas'] = \
        [[sample_scalar(self.blur_sigma, shape, dim=i + 1) for i in range(dims)]
            for _ in range(sum(dct['apply_to_channel']))] \
            if not self.synchronize_channels else \
            [sample_scalar(self.blur_sigma, shape[i + 1]) for i in range(dims)]
return dct

3. _apply_to_image函数

## 没有施加的通道就直接返回img
if len(params['apply_to_channel']) == 0:
    return img
dim = len(img.shape[1:])

## 如果所有通道采用一样的模糊核
if self.synchronize_channels:
    # 注释机翻:我们可以一次性计算完成,因为卷积实现支持任意输入通道(通过扩展的核)。
    ## 遍历每个空间维度(X/Y/Z)
    for d in range(dim):
    	## 如果不进行测试,也就是不区分普通卷积和FFT卷积
        if not self.benchmark:
            img[params['apply_to_channel']] = blur_dimension(img[params['apply_to_channel']], params['sigmas'][d], d)
        ## 如果区分普通卷积和FFT卷积,调用self._benchmark_wrapper函数
        else:
            img[params['apply_to_channel']] = self._benchmark_wrapper(img[params['apply_to_channel']], params['sigmas'][d], d)
else:
    # 注释机翻:我们需要遍历所有通道,为每个通道构建核等。
    ## 获取需要高斯模糊的通道索引,然后遍历它
    idx = np.where(params['apply_to_channel'])[0]
    for j, i in enumerate(idx):
    	## 遍历每个空间维度(X/Y/Z)
        for d in range(dim):
            ## 也是判断是否开启测试,同上
            if not self.benchmark:
                img[i:i+1] = blur_dimension(img[i:i+1], params['sigmas'][j][d], d)
            else:
                img[i:i+1] = self._benchmark_wrapper(img[i:i+1], params['sigmas'][j][d], d)
## 返回高斯模糊后的图像
return img

4. _benchmark_wrapper函数

本函数通过多次比较,选择速度更快的卷积方式(普通卷积或FFT卷积),然后调用blur_dimension函数实现高斯模糊

def _benchmark_wrapper(self, img: torch.Tensor, sigma: float, dim_to_blur: int):
	# 计算当前sigma对应的核大小
    kernel_size = _compute_kernel_size(sigma)
    shp = img.shape[dim_to_blur + 1]
    
    # 检查是否已有缓存结果
    if shp in self.benchmark_use_fft.keys() and kernel_size in self.benchmark_use_fft[shp].keys():
    	# 有缓存直接调用blur_dimension
        return blur_dimension(img, sigma, dim_to_blur, force_use_fft=self.benchmark_use_fft[shp][kernel_size])
    else:
        # 初始化缓存结构
        if shp not in self.benchmark_use_fft.keys():
            self.benchmark_use_fft[shp] = {}
        # 基准测试
        # 避免污染原始图像
        dummy_img = deepcopy(img)
        times_nonfft = []
        for _ in range(self.benchmark_num_runs):	# 多次运行减少误差
            st = time()
            blur_dimension(dummy_img, sigma, dim_to_blur, force_use_fft=False)
            times_nonfft.append(time() - st)
        times_fft = []
        for _ in range(self.benchmark_num_runs):
            st = time()
            blur_dimension(dummy_img, sigma, dim_to_blur, force_use_fft=True)
            times_fft.append(time() - st)
        # 比较中位时间并缓存结果
        self.benchmark_use_fft[shp][kernel_size] = np.median(times_fft) < np.median(times_nonfft)
        # 按核大小排序
        self.benchmark_use_fft[shp] = dict(sorted(self.benchmark_use_fft[shp].items()))	
        # 返回实际模糊结果,调用blur_dimension函数
        return blur_dimension(img, sigma, dim_to_blur, force_use_fft=self.benchmark_use_fft[shp][kernel_size])

5. _build_kernel 函数

构建高斯核,先确定高斯核的大小,再确定该大小能覆盖多少高斯分布,最后放缩数值,让总和为1

def _build_kernel(sigma: float, truncate: float = 4) -> torch.Tensor:
    """
    sigma: 高斯分布的标准差
    truncate: 截断范围系数,决定核包含高斯分布的范围,
              默认为4表示范围是[-4*sigma, 4*sigma]
              ±(4*sigma)已经覆盖了绝大部分高斯分布
    """
    # 计算核大小
    kernel_size = _compute_kernel_size(sigma, truncate=truncate)
    # 计算核半径,核大小一定是奇数,所以先减1再除2
    ksize_half = (kernel_size - 1) * 0.5
	# 根据核大小,生成等间距坐标点,eg: -2 -1 0 1 2
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    # 计算高斯概率密度函数
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    # 放缩数值大小,让所有元素和为1
    kernel1d = pdf / pdf.sum()

    return kernel1d	# 形状是1维

6. _round_to_nearest_odd 函数

def _round_to_nearest_odd(n):
    # 将一个数字舍入至最近的奇数
    # 先四舍五入
    rounded = round(n)
    # If the rounded number is odd, return it
    # 如果是奇数,直接返回
    if rounded % 2 == 1:
        return rounded
    # If the rounded number is even, adjust to the nearest odd number
    # 如果是偶数,调整为奇数后返回
    return rounded + 1 if n - rounded >= 0 else rounded - 1

7. _compute_kernel_size 函数

def _compute_kernel_size(sigma, truncate: float = 4):
    # 参数同_build_kernel 函数
   	# 计算核大小,加0.5确保在sigma很小时,kernel size大于1
    ksize = _round_to_nearest_odd(sigma * truncate + 0.5)
    return ksize

8. blur_dimension函数

def blur_dimension(img: torch.Tensor, sigma: float, dim_to_blur: int, force_use_fft: bool = None, truncate: float = 6):
    """
	img:输入图像张量,形状为(C, X)、(C, X, Y)或(C, X, Y, Z),其中C是通道维度,X、Y、Z是空间维度。
	sigma:高斯核的标准差。
	dim_to_blur:应用高斯模糊的维度(0对应X,1对应Y,2对应Z)。
    """
    ## 断言判断
    assert img.ndim - 1 > dim_to_blur, "dim_to_blur must be a valid spatial dimension of the input image."
    # 注释机翻:根据图像维度对核进行调整。
    spatial_dims = img.ndim - 1  # 注释机翻:输入图像中的空间维度数量。
    ## spatial_dims 和 dim_to_blur要区分开
    ## 调用_build_kernel函数构建kernel
    kernel = _build_kernel(sigma, truncate=truncate)

    ksize = kernel.shape[0]

    # 注释机翻:根据空间维度的数量动态设置填充、卷积操作和核的形状。
    conv_ops = {1: conv1d, 2: conv2d, 3: conv3d}
    ## 根据输入图像的空间维度(1D、2D、3D)选择对应的卷积函数(conv1d、conv2d、conv3d)
    ## 若启用 force_use_fft,则使用FFT卷积加速计算
    if force_use_fft is not None:
        conv_op = conv_ops[spatial_dims] if not force_use_fft else fft_conv
    else:
        conv_op = conv_ops[spatial_dims]

    # 注释机翻:根据指定的模糊维度和输入维度调整核和填充。
    ## 1D图像
    if spatial_dims == 1:
    	## kernel 形状为(1, 1, ksize)
        kernel = kernel[None, None, :]
        ## padding为图像左右填充像素的数量,均为ksize//2
        padding = [ksize // 2, ksize // 2]
    ## 2D图像
    elif spatial_dims == 2:
    	## 如果沿0轴模糊
        if dim_to_blur == 0:
            ## kernel 形状为(1, 1, ksize, 1)
            kernel = kernel[None, None, :, None]
            ## padding只在0轴方向填充
            padding = [0, 0, ksize // 2, ksize // 2]
        ## 如果沿1轴模糊
        else:  # dim_to_blur == 1
            ## kernel 形状为(1, 1, 1, ksize)
            kernel = kernel[None, None, None, :]
            ## padding只在1轴方向填充
            padding = [ksize // 2, ksize // 2, 0, 0]
    ## 3D图像,类比上面
    else:  # spatial_dims == 3
        # 注释机翻:根据模糊维度扩展核并调整填充。
        if dim_to_blur == 0:
            kernel = kernel[None, None, :, None, None]
            padding = [0, 0, 0, 0, ksize // 2, ksize // 2]
        elif dim_to_blur == 1:
            kernel = kernel[None, None, None, :, None]
            padding = [0, 0, ksize // 2, ksize // 2, 0, 0]
        else:  # dim_to_blur == 2
            kernel = kernel[None, None, None, None, :]
            padding = [ksize // 2, ksize // 2, 0, 0, 0, 0]

    # 注释机翻:应用padding
    img_padded = pad(img, padding, mode="reflect")	## 对称填充

    # 注释机翻:应用卷积
    # 注释机翻:记住权重的形状是 [c_out, c_in, ...]
    ## 通过分组卷积,对图像各个通道分别处理
    img_blurred = conv_op(img_padded[None], kernel.expand(img_padded.shape[0], *[-1] * (kernel.ndim - 1)), groups=img_padded.shape[0])[0]
    ## 返回模糊后的图像
    return img_blurred

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

开栈

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值