一文搞懂pytorch-image-models中的数据增强:CutMix与MixUp实现

一文搞懂pytorch-image-models中的数据增强:CutMix与MixUp实现

【免费下载链接】pytorch-image-models huggingface/pytorch-image-models: 是一个由 Hugging Face 开发维护的 PyTorch 视觉模型库,包含多个高性能的预训练模型,适用于图像识别、分类等视觉任务。 【免费下载链接】pytorch-image-models 项目地址: https://gitcode.com/GitHub_Trending/py/pytorch-image-models

你是否还在为图像分类模型过拟合而烦恼?是否想通过简单的代码调整就能显著提升模型性能?本文将带你深入了解pytorch-image-models(简称timm)库中两种强大的数据增强技术——CutMix与MixUp的实现原理与使用方法,读完你将能够轻松在自己的项目中应用这两种技术,有效提升模型的泛化能力。

CutMix与MixUp的基本原理

在深度学习中,数据增强是提升模型性能的重要手段。CutMix和MixUp是两种近年来广泛应用的混合增强技术,它们通过不同的方式将两张图像混合,从而生成新的训练样本,增加数据的多样性,减少过拟合。

MixUp的核心思想是将两张图像按照一定的比例进行线性插值混合,同时对应的标签也进行相应的混合。这种方法可以让模型学习到图像之间的线性关系,增强模型的鲁棒性。其数学公式如下:

$\hat{x} = \lambda x_i + (1-\lambda) x_j$

$\hat{y} = \lambda y_i + (1-\lambda) y_j$

其中,$x_i$和$x_j$是两张原始图像,$y_i$和$y_j$是对应的标签,$\lambda$是从Beta分布中采样得到的混合系数。

CutMix则是另一种混合增强方法,它不是对整个图像进行混合,而是随机裁剪一张图像的一个区域,然后将其粘贴到另一张图像的对应区域,同时调整标签的混合比例。这种方法可以保留图像的局部特征,有助于模型学习到更具判别性的特征。

timm库中CutMix与MixUp的实现

timm库中CutMix和MixUp的实现主要集中在timm/data/mixup.py文件中。该文件定义了Mixup类和FastCollateMixup类,分别实现了基本的混合增强功能和高效的混合增强功能。

Mixup类的核心实现

Mixup类是timm库中实现CutMix和MixUp的核心类,它的构造函数接受多个参数,用于配置混合增强的各种参数,如混合系数的Beta分布参数、使用CutMix的概率、标签平滑系数等。

class Mixup:
    """ Mixup/Cutmix that applies different params to each element or whole batch

    Args:
        mixup_alpha (float): mixup alpha value, mixup is active if > 0.
        cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
        cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
        prob (float): probability of applying mixup or cutmix per batch or element
        switch_prob (float): probability of switching to cutmix instead of mixup when both are active
        mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
        correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
        label_smoothing (float): apply label smoothing to the mixed target tensor
        num_classes (int): number of classes for target
    """
    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.cutmix_minmax = cutmix_minmax
        if self.cutmix_minmax is not None:
            assert len(self.cutmix_minmax) == 2
            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
            self.cutmix_alpha = 1.0
        self.mix_prob = prob
        self.switch_prob = switch_prob
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
        self.mode = mode
        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
        self.mixup_enabled = True  # set to false to disable mixing (intended tp be set by train loop)

Mixup类的__call__方法是其核心方法,它接受输入图像和标签,返回混合后的图像和标签。根据配置的mode参数,__call__方法会调用不同的混合函数,如_mix_batch_mix_pair_mix_elem,分别实现对整个批次、成对元素或单个元素的混合增强。

def __call__(self, x, target):
    assert len(x) % 2 == 0, 'Batch size should be even when using this'
    if self.mode == 'elem':
        lam = self._mix_elem(x)
    elif self.mode == 'pair':
        lam = self._mix_pair(x)
    else:
        lam = self._mix_batch(x)
    target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
    return x, target

CutMix的边界框生成

在CutMix中,边界框的生成是一个关键步骤。timm库中提供了rand_bbox函数和rand_bbox_minmax函数来生成不同类型的边界框。

rand_bbox函数根据混合系数$\lambda$生成一个随机的正方形边界框,其大小由$\lambda$决定。函数首先计算边界框的比例ratio = np.sqrt(1 - lam),然后根据图像的高度和宽度计算边界框的高度和宽度。最后,随机生成边界框的中心点坐标,并计算边界框的左上角和右下角坐标。

def rand_bbox(img_shape, lam, margin=0., count=None):
    """ Standard CutMix bounding-box
    Generates a random square bbox based on lambda value. This impl includes
    support for enforcing a border margin as percent of bbox dimensions.

    Args:
        img_shape (tuple): Image shape as tuple
        lam (float): Cutmix lambda value
        margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
        count (int): Number of bbox to generate
    """
    ratio = np.sqrt(1 - lam)
    img_h, img_w = img_shape[-2:]
    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
    yl = np.clip(cy - cut_h // 2, 0, img_h)
    yh = np.clip(cy + cut_h // 2, 0, img_h)
    xl = np.clip(cx - cut_w // 2, 0, img_w)
    xh = np.clip(cx + cut_w // 2, 0, img_w)
    return yl, yh, xl, xh

rand_bbox_minmax函数则根据给定的最小和最大比例生成一个随机的矩形边界框,这种方法可以生成更灵活的边界框形状。

混合标签的生成

混合后的标签生成是通过mixup_target函数实现的。该函数首先将原始标签转换为one-hot编码,然后根据混合系数$\lambda$对两个标签进行加权平均,生成混合后的标签。同时,函数还支持标签平滑,通过设置smoothing参数可以减少模型对标签的过度自信。

def mixup_target(target, num_classes, lam=1., smoothing=0.0):
    off_value = smoothing / num_classes
    on_value = 1. - smoothing + off_value
    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
    return y1 * lam + y2 * (1. - lam)

在训练中使用CutMix与MixUp

在timm库的训练脚本train.py中,可以很方便地启用CutMix和MixUp增强。通过设置相应的命令行参数,如--mixup--cutmix等,可以配置混合增强的各种参数。

命令行参数配置

train.py中,定义了多个与混合增强相关的命令行参数,用于配置Mixup类的各种参数:

parser.add_argument('--mixup', type=float, default=0.8,
                    help='mixup alpha, mixup enabled if > 0.')
parser.add_argument('--cutmix', type=float, default=1.0,
                    help='cutmix alpha, cutmix enabled if > 0.')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
                    help='probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='probability of switching to cutmix when both mixup and cutmix are enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='how to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

初始化Mixup对象

在训练脚本中,根据命令行参数初始化Mixup对象或FastCollateMixup对象。FastCollateMixupMixup的一个子类,它在数据加载时进行混合增强,可以提高训练效率。

mixup_args = dict(
    mixup_alpha=args.mixup,
    cutmix_alpha=args.cutmix,
    cutmix_minmax=args.cutmix_minmax,
    prob=args.mixup_prob,
    switch_prob=args.mixup_switch_prob,
    mode=args.mixup_mode,
    label_smoothing=args.label_smoothing,
    num_classes=num_classes,
)
if args.naflex_mixup:
    from timm.data import NaFlexMixup
    naflex_mixup_fn = NaFlexMixup(**mixup_args)
else:
    if args.prefetcher and not args.no_prefetcher:
        collate_fn = FastCollateMixup(**mixup_args)
    else:
        mixup_fn = Mixup(**mixup_args)

在训练循环中应用

在训练循环中,当进行前向传播时,调用Mixup对象的__call__方法对输入图像和标签进行混合增强:

if mixup_fn is not None:
    input, target = mixup_fn(input, target)

总结与展望

本文详细介绍了timm库中CutMix和MixUp两种数据增强技术的实现原理和使用方法。通过阅读timm/data/mixup.py源码,我们了解了混合增强的核心算法,包括混合系数的采样、边界框的生成、混合图像和标签的计算等。同时,通过分析train.py脚本,我们学习了如何在实际训练中配置和应用这些增强技术。

CutMix和MixUp作为简单有效的数据增强方法,已经在许多图像分类任务中取得了显著的性能提升。未来,随着深度学习技术的不断发展,可能会出现更多更有效的混合增强方法,如结合注意力机制的混合增强、动态调整混合策略的增强等。timm库作为一个活跃的开源项目,也会不断集成新的增强技术,为用户提供更强大的工具。

希望本文能够帮助你更好地理解和应用CutMix与MixUp技术。如果你对timm库中的数据增强还有其他疑问,欢迎查阅官方文档或源码,也可以在评论区留言讨论。最后,别忘了点赞、收藏、关注三连,下期我们将介绍timm库中其他强大的数据增强技术!

【免费下载链接】pytorch-image-models huggingface/pytorch-image-models: 是一个由 Hugging Face 开发维护的 PyTorch 视觉模型库,包含多个高性能的预训练模型,适用于图像识别、分类等视觉任务。 【免费下载链接】pytorch-image-models 项目地址: https://gitcode.com/GitHub_Trending/py/pytorch-image-models

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值