一文搞懂pytorch-image-models中的数据增强:CutMix与MixUp实现
你是否还在为图像分类模型过拟合而烦恼?是否想通过简单的代码调整就能显著提升模型性能?本文将带你深入了解pytorch-image-models(简称timm)库中两种强大的数据增强技术——CutMix与MixUp的实现原理与使用方法,读完你将能够轻松在自己的项目中应用这两种技术,有效提升模型的泛化能力。
CutMix与MixUp的基本原理
在深度学习中,数据增强是提升模型性能的重要手段。CutMix和MixUp是两种近年来广泛应用的混合增强技术,它们通过不同的方式将两张图像混合,从而生成新的训练样本,增加数据的多样性,减少过拟合。
MixUp的核心思想是将两张图像按照一定的比例进行线性插值混合,同时对应的标签也进行相应的混合。这种方法可以让模型学习到图像之间的线性关系,增强模型的鲁棒性。其数学公式如下:
$\hat{x} = \lambda x_i + (1-\lambda) x_j$
$\hat{y} = \lambda y_i + (1-\lambda) y_j$
其中,$x_i$和$x_j$是两张原始图像,$y_i$和$y_j$是对应的标签,$\lambda$是从Beta分布中采样得到的混合系数。
CutMix则是另一种混合增强方法,它不是对整个图像进行混合,而是随机裁剪一张图像的一个区域,然后将其粘贴到另一张图像的对应区域,同时调整标签的混合比例。这种方法可以保留图像的局部特征,有助于模型学习到更具判别性的特征。
timm库中CutMix与MixUp的实现
timm库中CutMix和MixUp的实现主要集中在timm/data/mixup.py文件中。该文件定义了Mixup类和FastCollateMixup类,分别实现了基本的混合增强功能和高效的混合增强功能。
Mixup类的核心实现
Mixup类是timm库中实现CutMix和MixUp的核心类,它的构造函数接受多个参数,用于配置混合增强的各种参数,如混合系数的Beta分布参数、使用CutMix的概率、标签平滑系数等。
class Mixup:
""" Mixup/Cutmix that applies different params to each element or whole batch
Args:
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
prob (float): probability of applying mixup or cutmix per batch or element
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
label_smoothing (float): apply label smoothing to the mixed target tensor
num_classes (int): number of classes for target
"""
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.cutmix_minmax = cutmix_minmax
if self.cutmix_minmax is not None:
assert len(self.cutmix_minmax) == 2
# force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
self.cutmix_alpha = 1.0
self.mix_prob = prob
self.switch_prob = switch_prob
self.label_smoothing = label_smoothing
self.num_classes = num_classes
self.mode = mode
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
Mixup类的__call__方法是其核心方法,它接受输入图像和标签,返回混合后的图像和标签。根据配置的mode参数,__call__方法会调用不同的混合函数,如_mix_batch、_mix_pair或_mix_elem,分别实现对整个批次、成对元素或单个元素的混合增强。
def __call__(self, x, target):
assert len(x) % 2 == 0, 'Batch size should be even when using this'
if self.mode == 'elem':
lam = self._mix_elem(x)
elif self.mode == 'pair':
lam = self._mix_pair(x)
else:
lam = self._mix_batch(x)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
return x, target
CutMix的边界框生成
在CutMix中,边界框的生成是一个关键步骤。timm库中提供了rand_bbox函数和rand_bbox_minmax函数来生成不同类型的边界框。
rand_bbox函数根据混合系数$\lambda$生成一个随机的正方形边界框,其大小由$\lambda$决定。函数首先计算边界框的比例ratio = np.sqrt(1 - lam),然后根据图像的高度和宽度计算边界框的高度和宽度。最后,随机生成边界框的中心点坐标,并计算边界框的左上角和右下角坐标。
def rand_bbox(img_shape, lam, margin=0., count=None):
""" Standard CutMix bounding-box
Generates a random square bbox based on lambda value. This impl includes
support for enforcing a border margin as percent of bbox dimensions.
Args:
img_shape (tuple): Image shape as tuple
lam (float): Cutmix lambda value
margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
count (int): Number of bbox to generate
"""
ratio = np.sqrt(1 - lam)
img_h, img_w = img_shape[-2:]
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
yl = np.clip(cy - cut_h // 2, 0, img_h)
yh = np.clip(cy + cut_h // 2, 0, img_h)
xl = np.clip(cx - cut_w // 2, 0, img_w)
xh = np.clip(cx + cut_w // 2, 0, img_w)
return yl, yh, xl, xh
rand_bbox_minmax函数则根据给定的最小和最大比例生成一个随机的矩形边界框,这种方法可以生成更灵活的边界框形状。
混合标签的生成
混合后的标签生成是通过mixup_target函数实现的。该函数首先将原始标签转换为one-hot编码,然后根据混合系数$\lambda$对两个标签进行加权平均,生成混合后的标签。同时,函数还支持标签平滑,通过设置smoothing参数可以减少模型对标签的过度自信。
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
off_value = smoothing / num_classes
on_value = 1. - smoothing + off_value
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
return y1 * lam + y2 * (1. - lam)
在训练中使用CutMix与MixUp
在timm库的训练脚本train.py中,可以很方便地启用CutMix和MixUp增强。通过设置相应的命令行参数,如--mixup、--cutmix等,可以配置混合增强的各种参数。
命令行参数配置
在train.py中,定义了多个与混合增强相关的命令行参数,用于配置Mixup类的各种参数:
parser.add_argument('--mixup', type=float, default=0.8,
help='mixup alpha, mixup enabled if > 0.')
parser.add_argument('--cutmix', type=float, default=1.0,
help='cutmix alpha, cutmix enabled if > 0.')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='probability of switching to cutmix when both mixup and cutmix are enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
help='how to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
初始化Mixup对象
在训练脚本中,根据命令行参数初始化Mixup对象或FastCollateMixup对象。FastCollateMixup是Mixup的一个子类,它在数据加载时进行混合增强,可以提高训练效率。
mixup_args = dict(
mixup_alpha=args.mixup,
cutmix_alpha=args.cutmix,
cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob,
switch_prob=args.mixup_switch_prob,
mode=args.mixup_mode,
label_smoothing=args.label_smoothing,
num_classes=num_classes,
)
if args.naflex_mixup:
from timm.data import NaFlexMixup
naflex_mixup_fn = NaFlexMixup(**mixup_args)
else:
if args.prefetcher and not args.no_prefetcher:
collate_fn = FastCollateMixup(**mixup_args)
else:
mixup_fn = Mixup(**mixup_args)
在训练循环中应用
在训练循环中,当进行前向传播时,调用Mixup对象的__call__方法对输入图像和标签进行混合增强:
if mixup_fn is not None:
input, target = mixup_fn(input, target)
总结与展望
本文详细介绍了timm库中CutMix和MixUp两种数据增强技术的实现原理和使用方法。通过阅读timm/data/mixup.py源码,我们了解了混合增强的核心算法,包括混合系数的采样、边界框的生成、混合图像和标签的计算等。同时,通过分析train.py脚本,我们学习了如何在实际训练中配置和应用这些增强技术。
CutMix和MixUp作为简单有效的数据增强方法,已经在许多图像分类任务中取得了显著的性能提升。未来,随着深度学习技术的不断发展,可能会出现更多更有效的混合增强方法,如结合注意力机制的混合增强、动态调整混合策略的增强等。timm库作为一个活跃的开源项目,也会不断集成新的增强技术,为用户提供更强大的工具。
希望本文能够帮助你更好地理解和应用CutMix与MixUp技术。如果你对timm库中的数据增强还有其他疑问,欢迎查阅官方文档或源码,也可以在评论区留言讨论。最后,别忘了点赞、收藏、关注三连,下期我们将介绍timm库中其他强大的数据增强技术!
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



