Debug系列 GroupNorm和BatchNorm出现Nan或inf的情况

前言

在复现别人论文的实验结果时,按照README乖乖做完之后,却发现损失函数的走向十分诡异,具体表现为在cifar10数据集上运行2万步以前,一切风平浪静,祥和美好,但是突然loss就全部变为nan了。令人费解。在Debug两天后,终于发现并修改了问题。令人发指。

在此记录,也希望各位以后遇到此类问题,能够快速解决提供参考。

如果你也想试试,就看这个论文 论文连接

这两个函数做了什么

x ′ = x − u i σ i 2 + ϵ ∗ W + γ x'=\frac{x - u_i}{\sqrt{\sigma_i^2+\epsilon}}*W + \gamma x=σi2+ϵ xuiW+γ
Normalize实际上是将当前数据 x x x的分布转换为一个标准正态分布,即均值 u i = 0 u_i=0 ui=0,方差 σ i = 1 \sigma_i=1 σi=1。后面的 W , γ W,\gamma W,γ是可以训练的参数, ϵ \epsilon ϵ是为了实现数学稳定,在分母上增添的小量。

在神经网络中添加这个层,可以改变数据的分布,从而使得不同阶段数据分布相同,从而实现加速神经网络收敛,提高表达能力的效果。

需要说明的是,这两个函数会受到model.eval()和model.train()的影响。

这两个函数本身的区别在于,一个是按batch进行normalization,另一个则是按照channel进行normalization。

可能出现的问题

根据公式,我们会很容易发现,每一轮次,当前层都会计算当前数据batch中的均值和方差。那么显而易见,当数据分布不合理,或出现问题时,均值和方差有可能计算得到inf或nan。

此外为什么有时候,在训练的时候没问题,测试的时候却出现问题了呢?

还有为什么我的模型loss很好,输出却总是黑色图像呢?

为什么loss一开始很好,突然就不好了呢?

解决方法

train和eval

最容易解决的方式便是,如果发现模型只在测试中出现问题,那么不妨一直采用model.train的模式来进行预测。

batchsize或channel设置过小

这两个网络层每次计算实际上是会根据当前数据的batchsize或channel数进行的,那么过小的batchsize当然会导致算法不稳定。一般来说单GPU,batchsize定在32是一个合理的范围,而channel的话我是采用了128以上。

可训练参数的问题

这一说法的话,直接冻结参数即可,以免受到模型收敛的影响。

数值溢出

对了,这就是我最想说的!!!!

你可能会觉得奇怪,这为什么会数值溢出呢?如果没有觉得奇怪,可能大多数人都会想到,是不是 σ = 0 \sigma=0 σ=0?但是转念一想,我不是加了一个 ϵ \epsilon ϵ吗?是的,一开始我也是这么认为的。但是在经过不断Debug的过程中,发现这都不是问题的关键。

可能聪明的你猜到了,这可能与输入数据x有关,但是x怎么会导致数值溢出呢?这是因为,在Normalization的过程中,计算得到的均值和方差可能是很大的值,而这个值的大小如果一旦超过这两个变量本身能表达的最大大小,那么就会导致数值溢出,从而产生inf和nan。

这里可能与python的直觉不符,但是由于pytorch中存在 torch.float16之类的类型,其最大能表示的数值甚至不超过1e6,太抽象了。因此,需要手动将变量的type变为torch.float32以上,才能更好的计算。

如果还是不行,附以下代码,是我手写的GroupNorm,里面可以用clip强行划定 σ \sigma σ的值,从而避免出现inf,当然如果方差本来就是很大,这样clip实际上并不能使数据的分布转变为标准正态分布,但是在超不是特别多的情况下,这种clip实际上是一种妥协的方法,最差情况,即方差接近无限大,那么这样clip实际上相当于将整个GroupNorm的功能给作废。这并不影响整体模型效果,无非是GroupNorm的优良性质没了。

class GroupNorm(nn.GroupNorm):

    def forward(self, x):

        if len(x.shape) == 4:
            N,C,H,W = x.size()
            G = self.num_groups
            assert C % G == 0

            x = x.view(N,G,-1)
            mean = x.mean(-1, keepdim=True)
            var = th.clip(x.var(-1, keepdim=True), max=1e7)

            x = (x-mean) / (var+self.eps).sqrt()
            x = x.view(N,C,H,W)
        else:
            N,C,H = x.size()
            G = self.num_groups
            assert C % G == 0

            x = x.view(N,G,-1)
            mean = x.mean(-1, keepdim=True)
            var = th.clip(x.var(-1, keepdim=True), max=1e7)

            x = (x-mean) / (var+self.eps).sqrt()
            x = x.view(N,C,H)
        return x

其它的方法

之前看到有一个说法是可以冻结均值和方差????不知道怎么做,可以查一下。

内容概要:本文档是详尽的 Android SDK 中文帮助文档,介绍了 Android SDK 的核心概念、组件、开发环境搭建、基础开发流程及常用工具使用指南。首先解释了 Android SDK 的定义及其核心价值,即提供标准化开发环境,使开发者能高效构建、测试、优化 Android 应用。接着详细列出了 SDK 的核心组件,包括 Android Studio、SDK Tools、Platform Tools、Build Tools、Android 平台版本和系统镜像。随后,文档提供了详细的环境搭建步骤,适用于 Windows、macOS 和 Linux 系统,并介绍了基础开发流程,以“Hello World”为例展示了从创建项目到运行应用的全过程。此外,还深入讲解了 ADB、AVD Manager 和 SDK Manager 等核心工具的功能和使用方法。最后,文档涵盖了调试与优化工具(如 Logcat、Profiler 和 Layout Inspector)、关键开发技巧(如多版本 SDK 兼容、Jetpack 库的使用和资源文件管理)以及常见问题的解决方案。 适合人群:具有初步编程知识,希望深入了解 Android 应用开发的开发者,尤其是新手开发者和有一定经验但需要系统化学习 Android SDK 的技术人员。 使用场景及目标:①帮助开发者快速搭建 Android 开发环境;②指导开发者完成基础应用开发,理解核心工具的使用;③提高开发效率,掌握调试与优化技巧;④解决常见开发过程中遇到的问题。 阅读建议:此文档内容全面且实用,建议读者按照章节顺序逐步学习,结合实际开发项目进行练习,尤其要注意动手实践环境搭建和基础开发流程,同时参考提供的扩展学习资源,进一步提升开发技能。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值