本文阅读的nnU-Net V2图像增强有亮度调整、对比度调整、低分辨率调整
各个类内的各个函数的调用关系见前文nnUNet V2代码——图像增强(一)的BasicTransform类
安装batchgeneratorsv2,nnU-Net V2关于图像增强的代码都在这个库中,点击链接,将其clone到本地后,在命令行进入文件夹内,pip install -e . 即可(注意-e后有个点)。
本文目录
一 MultiplicativeBrightnessTransform
该类包含亮度调整,继承自ImageOnlyTransform类,只对image施加,seg不施加。
代码在batchgeneratorsv2 \ transforms \ intensity \ brightness.py文件中
MultiplicativeBrightnessTransform代码比SpatialTransform类、GaussianNoiseTransform类、GaussianBlurTransform类简洁,但代码逻辑一致
1. __init__函数
定义必要的类内变量,代码清晰,不做粘贴,变量在用到时再介绍作用
2. get_parameters函数
def get_parameters(self, **data_dict) -> dict:
## 获取image大小
shape = data_dict['image'].shape
## 确定哪些通道要施加亮度调整
### self.p_per_channel = 1,nnU-Net V2对每个通道都施加亮度调整
apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
## 各通道同步施加相同的亮度调整
if self.synchronize_channels:
multipliers = torch.Tensor([sample_scalar(self.multiplier_range, image=data_dict['image'], channel=None)] * len(apply_to_channel))
## 各通道各自施加自己的亮度调整
else: ### self.synchronize_channels = False,nnU-Net V2不同步施加
multipliers = torch.Tensor([sample_scalar(self.multiplier_range, image=data_dict['image'], channel=c) for c in apply_to_channel])
## 收集参数后返回
return {
'apply_to_channel': apply_to_channel,
'multipliers': multipliers
}
sample_scalar函数见nnUNet V2代码——图像增强(一)的其余函数
3. _apply_to_image函数
def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
## 没有要调整的通道,直接返回
if len(params['apply_to_channel']) == 0:
return img
## 遍历施加亮度调整
for c, m in zip(params['apply_to_channel'], params['multipliers']):
img[c] *= m
return img
二 ContrastTransform
该类负责对比度调整,继承自ImageOnlyTransform类,只对image施加,seg不施加。
代码在batchgeneratorsv2 \ transforms \ intensity \ contrast.py文件中
1. __init__函数
定义必要的类内变量,代码清晰,不做粘贴,变量在用到时再介绍作用
2. get_parameters函数
和MultiplicativeBrightnessTransform的get_parameters函数一致
3. _apply_to_image函数
def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
if len(params['apply_to_channel']) == 0:
return img
## 遍历通道
for i in range(len(params['apply_to_channel'])):
c = params['apply_to_channel'][i]
## 获取图像某一通道的平均值
mean = img[c].mean()
## 是否保留数值范围,nnU-Net V2设置self.preserve_range = True
if self.preserve_range:
minm = img[c].min()
maxm = img[c].max()
## 对比度调整
img[c] -= mean
img[c] *= params['multipliers'][i]
img[c] += mean
## 是否保留数值范围
if self.preserve_range:
img[c].clamp_(minm, maxm)
return img
三 SimulateLowResolutionTransform类
该类负责施加低分辨率,继承自ImageOnlyTransform类,只对image施加,seg不施加。
代码在batchgeneratorsv2 \ transforms \ spatial\ low_resolution.py文件中
1. __init__函数
变量名称 | 含义 |
---|---|
self.scale | 图像放缩范围 |
self.synchronize_channels | 各通道是否施加相同的低分辨率处理 |
self.synchronize_axes | 各轴是否施加相同的低分辨率处理 |
self.ignore_axes | 某轴不能施加低分辨率处理,与nnUNet V2代码——图像增强(一)的Convert3DTo2DTransform和Convert2DTo3DTransform有关 |
self.allowed_channels | 可能会施加低分辨率处理的通道 |
self.p_per_channel | 某通道施加低分辨率处理的概率 |
self.upmodes | 各维度采样方法 |
self.scale = scale
self.synchronize_channels = synchronize_channels
self.synchronize_axes = synchronize_axes
self.ignore_axes = ignore_axes
self.allowed_channels = allowed_channels
self.p_per_channel = p_per_channel
self.upmodes = {
1: 'linear',
2: 'bilinear',
3: 'trilinear'
}
2. get_parameters函数
def get_parameters(self, **data_dict) -> dict:
shape = data_dict['image'].shape
## nnU-Net V2设置为None,所有通道按概率施加低分辨率处理
if self.allowed_channels is None:
apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
else:
apply_to_channel = [i for i in self.allowed_channels if torch.rand(1) < self.p_per_channel]
## nnU-Net V2设置为False
if self.synchronize_channels:
## nnU-Net V2设置为True
if self.synchronize_axes:
## 各通道、各轴施加相同的
scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=None, dim=None)] * (len(shape) - 1)] * len(apply_to_channel))
else:
## 各通道施加相同的,各轴施加各自的
scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=None, dim=d) for d in range(len(shape) - 1)]] * len(apply_to_channel))
else:
if self.synchronize_axes:
## 各轴施加相同的,各通道施加各自的
scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=c, dim=None)] * (len(shape) - 1) for c in apply_to_channel])
else:
## 各通道、各轴施加各自的
scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=c, dim=d) for d in range(len(shape) - 1)] for c in apply_to_channel])
## 对忽略的轴单独处理
if len(scales) > 0:
scales[:, self.ignore_axes] = 1
return {
'apply_to_channel': apply_to_channel,
'scales': scales
}
3. _apply_to_image函数
def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
orig_shape = img.shape[1:]
# 注释机翻:我们无法对这些内容进行批处理,因为每个通道的下采样 SHAP 值会有所不同。
for c, s in zip(params['apply_to_channel'], params['scales']):
## 按照放缩尺度确定下采样后的图像大小
new_shape = [round(i * j.item()) for i, j in zip(orig_shape, s)]
## 使用某一种最近邻插值进行采样
downsampled = interpolate(img[c][None, None], new_shape, mode='nearest-exact')
## 还原图像大小
img[c] = interpolate(downsampled, orig_shape, mode=self.upmodes[img.ndim - 1])[0, 0]
return img