eval()与train()(结合源码理解)

本文详细解析了PyTorch中eval()与train()函数的实质,重点阐述了它们在Batch Normalization(BN)层中的应用。eval()与train()主要影响BN层的全局统计特征moving_mean和moving_variance是否更新,从而区分训练和验证阶段。通过BN层的源码分析,揭示了这两个方法在模型运行过程中的核心作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

eval()与train()(结合源码理解)

结论

先上结论:eval和training实际上就是需要判断BN层(dorpout类似)中全局统计特征moving_mean和moving_variance是否需要迭代更新

1.eval()与train()的实质

请添加图片描述

pytorch.org上eval的源码
从源码上不难看出,eval()和trian()实际上就是遍历了所有子模块,然后把其中的training属性给设置为Ture和False。

2.结合BN层理解

那么这个training属性有什么用呢,我们直接定位BN层的源代码,参考理解
请添加图片描述
抛开中间的部分,就看绿色的部分可以发现module中的training到了这里,通过training给bn_training进行了赋值。看到这会很自然产生两个问题,这个training怎么来的,这个bn_training又是什么捏。
training从何而来
我们看看上面图片中,这个定义的类叫_BatchNorm,它继承自_NormBase而这个_NormBase又继承自module,如下图所示:
在这里插入图片描述
所以module的training属性也就自然而然继承了过来,接下来看看bn_training又有什么用。
在先前起到的_NormBase类的forward中的return结果中,我们找到了bn_training的用处,如下图所示:
在这里插入图片描述
这个参数有什么用呢,自然就要追溯一下F.batch_norm方法是如何使用的
参考飞桨的开发文档如下:
在这里插入图片描述
通过BN层的原理我们可以知道,要对每个Batch进行标准化需要的是两部分,一部分是全局统计特征(moving_mean和moving_variance),一部分是局部也就是minibatch的统计特征,μβ 和 σ2β。
前一部分通过对每个batch的数据特征计算迭代得到(类似参数的更新,但不完全是,后面会提到),后一部分对当前batch计算即可得到。

所以综上可以得到,eval和training实际上就是需要判断这个全局统计特征moving_mean和moving_variance是否需要迭代更新,这不也就对应了train和validation的两个阶段吗:)。

另外提到的BN层的迭代类似参数但不是的原因是:moving_mean和moving_variance不是模型的参数,而是作为一个buffer保存在模型中,所以其更新与参数的更新(loss反传)不一样,与require_grad要区别开哦。(这里后续会继续开一篇文章讲解buffer、parameter、module的相关)

以上就是eval()和train()的理解,感兴趣的小伙伴可以点点关注蹲蹲后续的一些分享,同时欢迎各位大佬评论区留言指正!:)

参考资料:
[1].飞桨参考文档https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/BatchNorm1D_cn.html#cn-api-nn-batchnorm1d
[2].pytorch官网
https://pytorch.org/

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值