本文阅读的nnU-Net V2图像增强有高斯噪声、高斯模糊
各个类内的各个函数的调用关系见前文nnUNet V2代码——图像增强(一)的BasicTransform类
安装batchgeneratorsv2,nnU-Net V2关于图像增强的代码都在这个库中,点击链接,将其clone到本地后,在命令行进入文件夹内,pip install -e . 即可(注意-e后有个点)。
一 ImageOnlyTransform
本类只对image施加数据增强,seg不操作
# 继承自BasicTransform
class ImageOnlyTransform(BasicTransform):
def apply(self, data_dict: dict, **params) -> dict:
# 只对image施加
if data_dict.get('image') is not None:
data_dict['image'] = self._apply_to_image(data_dict['image'], **params)
return data_dict
二 GaussianNoiseTransform
GaussianNoiseTransform负责施加高斯噪声,继承自ImageOnlyTransform类,只对image施加,seg不施加。代码位于batchgeneratorsv2 \ transforms \ intensity\ gaussian_noise.py
1. __init__函数
定义必要的类内变量,代码清晰,不做粘贴,变量在用到时再介绍作用
2. get_parameters函数
本函数主要确定在哪些通道上使用高斯噪声以及确定高斯噪声的 σ σ σ 大小:
# 先获取图像大小
shape = data_dict['image'].shape
dct = {}
# 依概率确定各个通道是否施加高斯噪声
# nnU-Net V2对每一个通道都施加高斯噪声
dct['apply_to_channel'] = torch.rand(shape[0]) < self.p_per_channel # self.p_per_channel = 1
# 施加高斯噪声的所有通道是否选用一样的sigmas(标准差),也就是是否同步
# self.synchronize_channels = True
# nnU-Net V2对所有通道施加相同标准差的高斯噪声
dct['sigmas'] = \
[sample_scalar(self.noise_variance, data_dict['image'])
for i in range(sum(dct['apply_to_channel']))] if not self.synchronize_channels \
else sample_scalar(self.noise_variance, data_dict['image'])
return dct
sample_scalar函数见nnUNet V2代码——图像增强(一)的其余函数
3. _apply_to_image函数
# 没有施加的通道就直接返回img
if sum(params['apply_to_channel']) == 0:
return img
# 调用_sample_gaussian_noise获取和img相同形状的高斯噪声
gaussian_noise = self._sample_gaussian_noise(img.shape, **params)
# 注入原图,返回
img[params['apply_to_channel']] += gaussian_noise
return img
4. _sample_gaussian_noise函数
# 如果params['sigmas']字段不是list类型,说明各通道同步施加相同的高斯噪声
if not isinstance(params['sigmas'], list):
# 确定施加的通道数目
num_channels = sum(params['apply_to_channel'])
# 创建一张高斯图,然后沿通道维度复制
gaussian = torch.normal(0, params['sigmas'], size=(1, *img_shape[1:]))
gaussian.expand((num_channels, *[-1]*(len(img_shape) - 1)))
# 如果params['sigmas']字段是list类型,说明各通道施加各自的高斯噪声
else:
# for循环遍历各个通道,创建独立的噪声图
gaussian = [
torch.normal(0, i, size=(1, *img_shape[1:])) for i in params['sigmas']
]
# 将所有的高斯图沿通道维度拼接起来
gaussian = torch.cat(gaussian, dim=0)
return gaussian
三 GaussianBlurTransform
GaussianBlurTransform类负责施加高斯模糊,并且使用可分离的高斯滤波器来提升速度。继承自ImageOnlyTransform类,只对image施加,seg不施加。
代码位于 batchgeneratorsv2 \ transforms \ noise \ gaussian_blur.py文件
1. __init__函数
定义必要的类内变量,代码清晰,不做粘贴,变量在用到时再介绍作用
2. get_parameters函数
# 获取image形状
shape = data_dict['image'].shape
# 获取image维度(二维 / 三维)
dims = len(shape) - 1
dct = {}
# 依概率确定各个通道是否施加高斯模糊
# self.p_per_channel = 0.5
dct['apply_to_channel'] = torch.rand(shape[0]) < self.p_per_channel
# 如果所有通道采用一样的模糊核
# self.synchronize_channels = False
# nnU-Net V2对模糊的通道采样不同的模糊核
if self.synchronize_axes:
dct['sigmas'] = \
[[sample_scalar(self.blur_sigma, shape, dim=None)] * dims
for _ in range(sum(dct['apply_to_channel']))] \
if not self.synchronize_channels else \
[sample_scalar(self.blur_sigma, shape, dim=None)] * dims
# 如果各个通道采用各自的模糊核
else:
dct['sigmas'] = \
[[sample_scalar(self.blur_sigma, shape, dim=i + 1) for i in range(dims)]
for _ in range(sum(dct['apply_to_channel']))] \
if not self.synchronize_channels else \
[sample_scalar(self.blur_sigma, shape[i + 1]) for i in range(dims)]
return dct
3. _apply_to_image函数
## 没有施加的通道就直接返回img
if len(params['apply_to_channel']) == 0:
return img
dim = len(img.shape[1:])
## 如果所有通道采用一样的模糊核
if self.synchronize_channels:
# 注释机翻:我们可以一次性计算完成,因为卷积实现支持任意输入通道(通过扩展的核)。
## 遍历每个空间维度(X/Y/Z)
for d in range(dim):
## 如果不进行测试,也就是不区分普通卷积和FFT卷积
if not self.benchmark:
img[params['apply_to_channel']] = blur_dimension(img[params['apply_to_channel']], params['sigmas'][d], d)
## 如果区分普通卷积和FFT卷积,调用self._benchmark_wrapper函数
else:
img[params['apply_to_channel']] = self._benchmark_wrapper(img[params['apply_to_channel']], params['sigmas'][d], d)
else:
# 注释机翻:我们需要遍历所有通道,为每个通道构建核等。
## 获取需要高斯模糊的通道索引,然后遍历它
idx = np.where(params['apply_to_channel'])[0]
for j, i in enumerate(idx):
## 遍历每个空间维度(X/Y/Z)
for d in range(dim):
## 也是判断是否开启测试,同上
if not self.benchmark:
img[i:i+1] = blur_dimension(img[i:i+1], params['sigmas'][j][d], d)
else:
img[i:i+1] = self._benchmark_wrapper(img[i:i+1], params['sigmas'][j][d], d)
## 返回高斯模糊后的图像
return img
4. _benchmark_wrapper函数
本函数通过多次比较,选择速度更快的卷积方式(普通卷积或FFT卷积),然后调用blur_dimension函数实现高斯模糊
def _benchmark_wrapper(self, img: torch.Tensor, sigma: float, dim_to_blur: int):
# 计算当前sigma对应的核大小
kernel_size = _compute_kernel_size(sigma)
shp = img.shape[dim_to_blur + 1]
# 检查是否已有缓存结果
if shp in self.benchmark_use_fft.keys() and kernel_size in self.benchmark_use_fft[shp].keys():
# 有缓存直接调用blur_dimension
return blur_dimension(img, sigma, dim_to_blur, force_use_fft=self.benchmark_use_fft[shp][kernel_size])
else:
# 初始化缓存结构
if shp not in self.benchmark_use_fft.keys():
self.benchmark_use_fft[shp] = {}
# 基准测试
# 避免污染原始图像
dummy_img = deepcopy(img)
times_nonfft = []
for _ in range(self.benchmark_num_runs): # 多次运行减少误差
st = time()
blur_dimension(dummy_img, sigma, dim_to_blur, force_use_fft=False)
times_nonfft.append(time() - st)
times_fft = []
for _ in range(self.benchmark_num_runs):
st = time()
blur_dimension(dummy_img, sigma, dim_to_blur, force_use_fft=True)
times_fft.append(time() - st)
# 比较中位时间并缓存结果
self.benchmark_use_fft[shp][kernel_size] = np.median(times_fft) < np.median(times_nonfft)
# 按核大小排序
self.benchmark_use_fft[shp] = dict(sorted(self.benchmark_use_fft[shp].items()))
# 返回实际模糊结果,调用blur_dimension函数
return blur_dimension(img, sigma, dim_to_blur, force_use_fft=self.benchmark_use_fft[shp][kernel_size])
5. _build_kernel 函数
构建高斯核,先确定高斯核的大小,再确定该大小能覆盖多少高斯分布,最后放缩数值,让总和为1
def _build_kernel(sigma: float, truncate: float = 4) -> torch.Tensor:
"""
sigma: 高斯分布的标准差
truncate: 截断范围系数,决定核包含高斯分布的范围,
默认为4表示范围是[-4*sigma, 4*sigma]
±(4*sigma)已经覆盖了绝大部分高斯分布
"""
# 计算核大小
kernel_size = _compute_kernel_size(sigma, truncate=truncate)
# 计算核半径,核大小一定是奇数,所以先减1再除2
ksize_half = (kernel_size - 1) * 0.5
# 根据核大小,生成等间距坐标点,eg: -2 -1 0 1 2
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
# 计算高斯概率密度函数
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
# 放缩数值大小,让所有元素和为1
kernel1d = pdf / pdf.sum()
return kernel1d # 形状是1维
6. _round_to_nearest_odd 函数
def _round_to_nearest_odd(n):
# 将一个数字舍入至最近的奇数
# 先四舍五入
rounded = round(n)
# If the rounded number is odd, return it
# 如果是奇数,直接返回
if rounded % 2 == 1:
return rounded
# If the rounded number is even, adjust to the nearest odd number
# 如果是偶数,调整为奇数后返回
return rounded + 1 if n - rounded >= 0 else rounded - 1
7. _compute_kernel_size 函数
def _compute_kernel_size(sigma, truncate: float = 4):
# 参数同_build_kernel 函数
# 计算核大小,加0.5确保在sigma很小时,kernel size大于1
ksize = _round_to_nearest_odd(sigma * truncate + 0.5)
return ksize
8. blur_dimension函数
def blur_dimension(img: torch.Tensor, sigma: float, dim_to_blur: int, force_use_fft: bool = None, truncate: float = 6):
"""
img:输入图像张量,形状为(C, X)、(C, X, Y)或(C, X, Y, Z),其中C是通道维度,X、Y、Z是空间维度。
sigma:高斯核的标准差。
dim_to_blur:应用高斯模糊的维度(0对应X,1对应Y,2对应Z)。
"""
## 断言判断
assert img.ndim - 1 > dim_to_blur, "dim_to_blur must be a valid spatial dimension of the input image."
# 注释机翻:根据图像维度对核进行调整。
spatial_dims = img.ndim - 1 # 注释机翻:输入图像中的空间维度数量。
## spatial_dims 和 dim_to_blur要区分开
## 调用_build_kernel函数构建kernel
kernel = _build_kernel(sigma, truncate=truncate)
ksize = kernel.shape[0]
# 注释机翻:根据空间维度的数量动态设置填充、卷积操作和核的形状。
conv_ops = {1: conv1d, 2: conv2d, 3: conv3d}
## 根据输入图像的空间维度(1D、2D、3D)选择对应的卷积函数(conv1d、conv2d、conv3d)
## 若启用 force_use_fft,则使用FFT卷积加速计算
if force_use_fft is not None:
conv_op = conv_ops[spatial_dims] if not force_use_fft else fft_conv
else:
conv_op = conv_ops[spatial_dims]
# 注释机翻:根据指定的模糊维度和输入维度调整核和填充。
## 1D图像
if spatial_dims == 1:
## kernel 形状为(1, 1, ksize)
kernel = kernel[None, None, :]
## padding为图像左右填充像素的数量,均为ksize//2
padding = [ksize // 2, ksize // 2]
## 2D图像
elif spatial_dims == 2:
## 如果沿0轴模糊
if dim_to_blur == 0:
## kernel 形状为(1, 1, ksize, 1)
kernel = kernel[None, None, :, None]
## padding只在0轴方向填充
padding = [0, 0, ksize // 2, ksize // 2]
## 如果沿1轴模糊
else: # dim_to_blur == 1
## kernel 形状为(1, 1, 1, ksize)
kernel = kernel[None, None, None, :]
## padding只在1轴方向填充
padding = [ksize // 2, ksize // 2, 0, 0]
## 3D图像,类比上面
else: # spatial_dims == 3
# 注释机翻:根据模糊维度扩展核并调整填充。
if dim_to_blur == 0:
kernel = kernel[None, None, :, None, None]
padding = [0, 0, 0, 0, ksize // 2, ksize // 2]
elif dim_to_blur == 1:
kernel = kernel[None, None, None, :, None]
padding = [0, 0, ksize // 2, ksize // 2, 0, 0]
else: # dim_to_blur == 2
kernel = kernel[None, None, None, None, :]
padding = [ksize // 2, ksize // 2, 0, 0, 0, 0]
# 注释机翻:应用padding
img_padded = pad(img, padding, mode="reflect") ## 对称填充
# 注释机翻:应用卷积
# 注释机翻:记住权重的形状是 [c_out, c_in, ...]
## 通过分组卷积,对图像各个通道分别处理
img_blurred = conv_op(img_padded[None], kernel.expand(img_padded.shape[0], *[-1] * (kernel.ndim - 1)), groups=img_padded.shape[0])[0]
## 返回模糊后的图像
return img_blurred