MECTA: Memory-Economic Continual Test-time Adaptation

许多机器学习应用需要部署良好训练的深度学习从大数据集到分布外数据(out-of-distribution)和动态改变的环境。现有的一些研究目标是通过持续测试时适应(Continual Test-times Adaption,CTA)解决这个现有的研究挑战。无监督、资源受限与动态测试时环境使得CTA是一个挑战性的学习问题并希望一个自监督、高效且稳定的解决方案。先前的研究Tent,EATA指出基于所有测试时数据不进行任何训练更新BN层显著提升了OOD性能。Tent通过更新少量参数以高效方式最小化预测熵实现显著提升测试时性能。EATA提高采样效率,避免分布内数据的灾难性遗忘。

尽管Tent和EATA已经通过在线优化获得了在OOD精度令人印象深刻的进步,但这种优化伴随着大量内存消耗这在很多现实世界的CTA应用是禁止的。因为很多设备只被设计为硬件上推断而不是训练,内存首先的设备不能负担CTA算法。大高峰内存消耗使得EARA或Tent不能在边缘设备熵部署。

观察model.backward的瓶颈,一种直接的解决方式可以是减少批尺寸和模型规模(包括通道和层数量)。但同时维护模型性能存在很多阻碍。首先大批尺寸对于适应是必要的。其次,深且宽的模型抽取的信息量对于建模分布鲁棒的语义特征是必要的。

本文方法被简单封装在一个规范层(MECTA Norm)中,以减少中间缓存的三个维度:批尺寸、通道和层尺寸。MECTA Norm从流式小批数据中积累分布知识。类似于稀疏梯度下降,本文引入了测试时剪枝,在不知道梯度大小情况下随机删除缓存的中间结果通道。遗忘门也指导层适应。如果层分布差距足够小,层将被排除在内存密集训练外。

问题定义

定义 P ( x ) P(x) P(x) 是从分布集合 P \mathbb{P} P 采样的数据分布, P 0 P_{0} P0 是训练分布。假定分布 P ∼ P P\sim \mathcal{P} PP P 0 P_{0} P0 相同或偏离较多。神经网络模型 f θ f_{\theta} fθ 在训练集上以 θ 0 = min ⁡ θ ∈ Θ E x ∼ P 0 ( x ) [ ℓ ( f θ ( x ) , c ( x ) ) ] \theta_{0}=\min_{\theta\in \Theta}\mathbb{E}_{x\sim P_{0}(x)}[\ell(f_{\theta}(x),c(x))] θ0=minθΘExP0(x)[(fθ(x),c(x))] 方式预训练。损失的一个例子是交叉熵损失。基于预训练的模型,Tent和EATA通过如下迭代更新的方式持续适应:
θ t = Opt θ ∈ Θ t ( E x ∼ P t ( x ) [ H ( f θ ( x ) ) ] , θ t − 1 ) \theta_{t}=\text{Opt}_{\theta\in \Theta_{t}}(\mathbb{E}_{x\sim P_{t}(x)}[H(f_{\theta}(x))],\theta_{t-1}) θt=OptθΘt(ExPt(x)[H(fθ(x))],θt1)
没有标签时,熵函数 H ( f θ ( x ) ) = ℓ e n t ( f θ ( x ) , f θ ( x ) ) H(f_{\theta}(x))=\ell_{ent}(f_{\theta}(x),f_{\theta}(x)) H(fθ(x))=ent(fθ(x),fθ(x)) 类似于Tent中具有自监督的交叉熵损失。

参数高效适应

因为测试时自适应高效是重要的,本文关注与CTA的SOTA高效解决方案:EATA和Tent,这两种方法都采用一步梯度下降。 Opt θ ∈ Θ t ( E H , θ t − 1 ) = θ t − 1 − η ∂ ∂ θ E H \text{Opt}_{\theta\in \Theta_{t}}(\mathbb{E}H,\theta_{t-1})=\theta_{t-1}-\eta\frac{\partial}{\partial \theta}\mathbb{E}H OptθΘt(EH,θt1)=θt1ηθEH 学习率 η \eta η。为了高效训练模型,本文约束参数空间到原始空间的一个子空间,记为 Θ t = Θ ~ ⊂ Θ \Theta_{t}=\tilde{\Theta}\subset \Theta Θt=Θ~Θ。本文算法遵循Tent和EATA使得BN层的参数可训练。

这里简单介绍BN层。定义第l层的输入为 x l x^{l} xl,维度空间为 B × C l × H l × W l B\times C^{l}\times H^{l}\times W^{l} B×Cl×Hl×Wl。BN层操作定义为两个连续的逐通道操作。

z n , i , j , k = x n , i , j , k l − μ i l σ i 2 + ϵ 0 , a n , i , j , k l = γ i l z n , i , j , k + b i l z_{n,i,j,k}=\frac{x_{n,i,j,k}^{l}-\mu_{i}^{l}}{\sqrt{\sigma_{i}^{2}+\epsilon_{0}}},\quad a_{n,i,j,k}^{l}=\gamma_{i}^{l}z_{n,i,j,k}+b_{i}^{l} zn,i,j,k=σi2+ϵ0 xn,i,j,klμil,an,i,j,kl=γilzn,i,j,k+bil
μ i l \mu_{i}^{l} μil σ i l 2 \sigma_{i}^{l^{2}} σil2 是通道i上平均值和方差。

本文方法

本文方法关注于提升基于梯度适应方法的内存效率。首先直接推导显示应该存储中间表示来计算仿射层的梯度,因此会产生巨大的内存开销。假定在n样本上损失为 ℓ n \ell_{n} n。对于第i通道上的仿射权重 γ i \gamma_{i} γi 可以表示为:
∑ n = 1 B ∂ ℓ n ∂ γ i l = ∑ n = 1 B ∑ j = 1 W ∑ k = 1 W ∂ ℓ n ∂ a i , j , k l z n , i , j , k l \sum_{n=1}^{B}\frac{\partial \ell_{n}}{\partial \gamma_{i}^{l}}=\sum_{n=1}^{B}\sum_{j=1}^{W}\sum_{k=1}^{W}\frac{\partial \ell_{n}}{\partial a_{i,j,k}^{l}}z_{n,i,j,k}^{l} n=1Bγiln=n=1Bj=1Wk=1Wai,j,klnzn,i,j,kl

为了计算梯度,每个BN层需要在前向传播中存储维度为 B × C l × W l × H l B\times C^{l}\times W^{l}\times H^{l} B×Cl×Wl×Hl 的正则化的表示 z l z^{l} zl 直到 ∂ ℓ n / ∂ a i , j , k l \partial \ell_{n}/\partial a_{i,j,k}^{l} n/ai,j,kl 可用。对于L层网络,仅推断内存消耗是 R f w d = max ⁡ l ∈ { 1 , … , L } B × C l × W l × H l R_{fwd}=\max_{l\in \{1,\ldots,L\}}B\times C^{l}\times W^{l}\times H^{l} Rfwd=maxl{1,,L}B×Cl×Wl×Hl。相对地,后向传播中对应的累计中间内存为 R b w d = ∑ l = 1 L B × C l × W l × H l ≥ R f w d R_{bwd}=\sum_{l=1}^{L}B\times C^{l}\times W^{l}\times H^{l}\geq R_{fwd} Rbwd=l=1LB×Cl×Wl×HlRfwd。为了减少内存负荷,一个直接想法是通过丢弃 { z l } l = 1 L \{z^{l}\}_{l=1}^{L} {zl}l=1L 对应的条目减少 B B B C l C^{l} Cl L L L。因为丢弃 { z l } l = 1 L \{z^{l}\}_{l=1}^{L} {zl}l=1L 中条目会消失相应的梯度,如果不小心处理该策略将很容易破坏学习过程。

减少B

由于足够数量的样本对于BN层的准确统计 ( μ \mu μ σ 2 \sigma^{2} σ2) 估计至关重要,因此减少一批数据将使得归一化的统计数据产生偏差。在标准或测试时训练中,指数移动平均(EMA)已经广泛使用于通过记忆流批数据缓解偏差。为了在校准的统计上保留梯度的特性,本文遵循Test-time batch normalization与Improving robustness against common corruptions by covariate shift adaptation.实现在测试时的EMA归一化。定义 ϕ \phi ϕ 是组合元组 [ μ , σ ] [\mu,\sigma] [μ,σ] ϕ t \phi_{t} ϕt 是第t周期的运行统计, ϕ ~ t \tilde{\phi}_{t} ϕ~t 是第t批的统计。EMA统计表示为:
ϕ t = ( 1 − β ) ϕ t − 1 + β ϕ ~ t \phi_{t}=(1-\beta)\phi_{t-1}+\beta\tilde{\phi}_{t} ϕt=(1β)ϕt1+βϕ~t

传统方法中 β \beta β 在运行时是固定的,该策略无法适应动态分布 P t P_{t} Pt 的估计。当一个模型在一个单一域 P t ≈ P t − 1 P_{t}\approx P_{t-1} PtPt1 上稳定运行, β \beta β 应该是小的以保留尽可能多数据点支持准确的统计估计。 相对地,当分布偏移 P t ≠ P t − 1 P_{t}\neq P_{t-1} Pt=Pt1 β \beta β 应该是大的以避免混合两个不同的统计。根据这个直觉,本文引入了遗忘门自适应校准 β \beta β β t = h ( ϕ t − 1 , ϕ ~ t ) \beta_{t}=h(\phi_{t-1},\tilde{\phi}_{t}) βt=h(ϕt1,ϕ~t) h ( ⋅ , ⋅ ) h(\cdot,\cdot) h(,) 考虑一个无参数的启发式定义:
β t = 1 − e − D ( ϕ t − 1 , ϕ ~ t ) , D ( ϕ t − 1 , ϕ ~ t ) = 1 C ∑ i = 1 C K L ( ϕ t − 1 , i ∣ ∣ ϕ ~ t , i ) + K L ( ϕ ~ t , i ∣ ∣ ϕ t − 1 , i ) \beta_{t}=1-e^{-D(\phi_{t-1},\tilde{\phi}_{t})},\quad D(\phi_{t-1},\tilde{\phi}_{t})=\frac{1}{C}\sum_{i=1}^{C}KL(\phi_{t-1,i}||\tilde{\phi}_{t,i})+KL(\tilde{\phi}_{t,i}||\phi_{t-1,i}) βt=1eD(ϕt1,ϕ~t),D(ϕt1,ϕ~t)=C1i=1CKL(ϕt1,i∣∣ϕ~t,i)+KL(ϕ~t,i∣∣ϕt1,i)
D ( ⋅ , ⋅ ) D(\cdot,\cdot) D(,) 是定义的测量分布偏移的距离函数。KL散度定义为 log ⁡ σ 2 i − log ⁡ σ 1 i + 1 2 σ 2 1 2 ( σ 1 i 2 + ( μ 1 − μ 2 ) 2 − 1 2 \log\sigma_{2}^{i}-\log\sigma_{1}^{i}+\frac{1}{2\sigma_{2}^{1^{2}}}(\sigma_{1}^{i^{2}}+(\mu_{1}-\mu_{2})^{2}-\frac{1}{2} logσ2ilogσ1i+2σ2121(σ1i2+(μ1μ2)221。距离函数受到Revisiting batch normalization for practical domain adaptation启发,该研究显示 ϕ 1 \phi_{1} ϕ1 ϕ 2 \phi_{2} ϕ2 在来不同域时有较大KL散度。另外 β t l \beta_{t}^{l} βtl 逐层估计,因为不同层分布将偏向不同角度。当第l-1层在纠正 ϕ t l \phi_{t}^{l} ϕtl 后良好对齐,深层应该更好对齐。

减少通道C

容易看出在 z z z 中删除通道 i i i 的缓存将使得通道处的相应梯度消失并可能使得相应的仿射参数未拟合。平凡的剪枝通道可能会带来严重的问题,特别是当一些通道对于OOD泛化至关重要时。在计算之前很难预测哪个梯度如此重要而不需要剪枝。因此需要一种不依赖于反向传播的高效剪枝策略。本文提出了一种通过每周期生成随机掩码 M M M 的无条件剪枝策略,该策略中 q × 100 % q\times 100\% q×100% 的掩码条目是0剩下是1。

本文的剪枝策略有多重优点。由于前向传播不受到影响,因此仍然可以在保留高质量语义特征的网络全尺寸上进行预测。剪枝通过更小的中间缓存和梯度显著降低了内存使用率。每次迭代重复计算的掩码类似于逐渐式学习范式,现代优化器(SGD或Adam)中动量技术可以估算缺失的梯度。而且该方法缓解了灾难性遗忘,因为只更新了仿射权重的一个子集,且不更新低幅度值参数。给定梯度 g t g_{t} gt,模型差异 ∣ ∣ θ t − θ 0 ∣ ∣ = ∣ ∣ ∑ t g t ∣ ∣ ≤ O ( ∑ t ∣ ∣ g t ∣ ∣ ) ||\theta_{t}-\theta_{0}||=||\sum_{t}g_{t}||\leq \mathcal{O}(\sum_{t}||g_{t}||) ∣∣θtθ0∣∣=∣∣tgt∣∣O(t∣∣gt∣∣) 将因为一些 g t g_{t} gt 中的0减少,这可以被是为EATA中隐式反遗忘正则项。

动态L

大多数测试时自适应将在单个环境中持续很长时间。长时间持续将模型适应同一环境不会持续改进模型,反而会浪费资源。 因此如果优化收敛,本文建议停止反向传播,并在需要自适应时重新启动。原则是当数据分布发生根本性变化时需要进行调整。回顾已经通过 β t = 1 − e − D ( ϕ t − 1 , ϕ ~ t ) \beta_{t}=1-e^{-D(\phi_{t-1},\tilde{\phi}_{t})} βt=1eD(ϕt1,ϕ~t) 测量了分布偏移,这里重用该度量来指导适应。使用阈值 β t h \beta_{th} βth 做决定。如果 β t l ≤ θ t h \beta_{t}^{l}\leq \theta_{th} βtlθth z l z^{l} zl 缓存用于后向传播。因为分层决策,任何层可以在网络完全执行或所有层优化收敛之前提前停止训练以节省大量内存。该策略引入了动态内存分配,使得多任务移动系统获益。

与先前内存约简方法的关键区别

先前研究Sparcl: Sparse continual learning on the edge与Mest: Accurate and fast memory-economic sparse training framework on the edge关注于参数稀疏度并减少存储模型参数和梯度的负担。然而参数或梯度的负担对于后传的缓存时相对小的。本文的通道剪枝和按需要层训练类似于梯度稀疏或坐标下降。这两种策略在反向传播之前隐式剪枝梯度避免了大型缓存。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

qgh1223

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值