持续学习 – 深入探讨弹性权重巩固损失

原文:towardsdatascience.com/continual-learning-a-deep-dive-into-elastic-weight-consolidation-loss-7cda4a2d058c

简介

训练人工神经网络时面临的最具挑战性的问题之一是灾难性遗忘。当神经网络在一个任务(任务 A)上训练后,随后学习一个新的任务(任务 B),在这个过程中,忘记了如何执行原始任务时,就会出现这个问题。在本文中,我们将探讨一种解决此问题的方法,称为弹性权重巩固(EWC)。EWC 提供了一种有前景的方法来减轻灾难性遗忘,使神经网络能够在学习新技能的同时保留先前学习任务的知识。

本文中所有图表均由作者提供,除非另有说明

直觉

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/a299a9f94f5b24030968e7cda2250253.png

图 1:EWC 的直觉,来自论文

已经证明,存在许多配置的最佳参数,在任务上具有所需的低误差 – 上图中任务 A 和 B 的灰色和黄色区域分别表示。假设我们找到了这样一个配置 θꭺ 用于任务 A,当继续从这种配置训练模型到新任务 B 时,我们面临三种不同的场景:

  1. 简单地继续在任务 B 上训练而不施加惩罚,最终会导致任务 B 的低水平区域,但在任务 A 上的表现低于期望的准确性 – 蓝色箭头。

  2. 使用任务 A 的权重 L2 约束可能过于强烈,使得模型在任务 A 上表现良好,但在任务 B 上欠拟合 – 绿色箭头。

  3. EWC 是提出的解决方案,将在模型在两个任务上都表现良好的区域找到参数(两个区域的交集) – 红色箭头。

接下来,我们将解释这是如何完成的。

理论

鱼际信息矩阵(FIM)

让我们先了解 EWC 方法所基于的 FIM(Fisher 信息矩阵)。FIM 是一种统计量度,它量化了给定数据为我们提供关于我们旨在估计的未知参数 θ 的信息量。在持续学习的背景下,FIM 将有助于识别神经网络参数,这些参数对先前任务的数据提供的信息较少。通过更新这些参数,网络可以在不删除存储在参数中的重要信息的情况下学习新任务,这些参数对先前学习任务的信息非常丰富。

为了更加正式,假设 X 是一个随机变量,其概率密度函数由 θ 参数化,表示为 f(x|θ)。对于样本 x 的似然函数,它仅是参数的函数,数据保持固定,可以表示为:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f41c793fbba09b91a19ca3ad9c062bf0.png

图 2:似然函数 FIM

和对数似然:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/a31d63c770f34197c1475b22fe8530dd.png

图 3:对数似然函数 FIM

我们现在可以定义 FIM 为:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/28a7d6c5bda92226a9d5530dbeea165c.png

图 4:FIM 的定义(a)

这表明对数似然函数对参数的小幅度变化有多敏感。一个等效的定义,稍后将会用到,是我们可以将 FIM 看作似然函数二阶导数的负期望:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/b3e77418e1abd3d4521a4a0f8d3ff439.png

图 5:FIM 的定义(b)

在求二阶导数时,我们基本上是在观察似然函数的曲率。

为了说明这一点,考虑下面的图表,其中绘制了两个似然函数。蓝色曲线代表一个在其峰值周围非常狭窄的分布,表明我们的数据在θ附近的可能性很大,并且当我们远离它时迅速下降。相比之下,黑色曲线代表一个更宽泛的分布,其中数据即使在远离θ时也保持相似的可能性。FIM 量化了这个概念——我们的数据如何紧密地约束在θ的某个特定值上。大的 FIM(如蓝色曲线所示)意味着参数值的小幅变化将导致在这些参数下数据似然度的显著下降。相反,小的 FIM(如黑色曲线所示)意味着参数值的小幅变化将导致似然度的小幅降低。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/ab0ad3663dc5cbc9a775a44f3734135e.png

图 6:FIM 的视觉说明

结果表明,Fisher 信息矩阵与数据的方差(或在多元情况下的协方差)成反比。在上面的图表中,如果我们假设曲线代表两个具有均值 _θ_₀和方差σ²ᵦₗᵤₑ和σ²ᵦₗₐ꜀ₖ的 Gaussian 分布,其中σ²ᵦₗᵤₑ < σ²ᵦₗₐ꜀ₖ,FIM 等于 1/σ²,因此蓝色曲线包含更多的信息。

弹性权重巩固

给定数据D和一个具有参数θ的神经网络,我们的目标是最大化给定数据的参数概率p(θ|D)。根据贝叶斯定理,我们有:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/42737e17f4a4c2e60d0c9f4c5e015ec5.png

图 7:最大化目标

将对数应用于等式的两边不会改变最大化目标,因为对数是一个单调变换。因此,我们的目标变为:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/c2dd674df2d571e323a41fa4b2f79335.png

图 8:对数最大化目标

假设有两个独立任务D = {A, B},我们有:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/9ee426278584dc5bae18bdf2ab997a3d.png

图 9:两个任务的对数最大化目标

最后一个结果来自 A 和 B 的独立性。在这里,log(p(B|θ)) 是任务 B 的损失,log(p(B)) 是 B 的似然,对于我们的优化来说可以被视为常数,因为它不依赖于 θ,而 log(p(θ|A)) 是任务 A 的后验分布,包含有关哪些参数对任务 A 重要的所有信息。

估计 log(p(θ|A)) 是不可行的,因为计算它将涉及到在整个参数空间上对高维函数进行积分。因此,它被一个具有均值 A 任务最优参数 – θꭺ 和方差 Fisher 信息矩阵的正态分布所近似。这种近似是有意义的,因为我们假设 A 和 B 任务的新参数 θ 将不会远离 A 任务的最优参数。此外,在 θꭺ 的所有参数中,将有一些参数对于 A 任务的良好性能更为重要,我们不想让它们变化太大,这就是 FIM 发挥作用的地方,其值,如前所述,表示改变某个参数将如何影响此案例中 A 任务的损失,因此具有更高 FIM 值的参数变化会受到更多惩罚。

现在,让我们对任务 A 的最优权重进行到二阶项的泰勒展开:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f6462b42ab241837b12b904212cdea44.png

图 10:围绕 A 任务最优参数的泰勒展开

在优化过程中,log(p(θꭺ|A))* 是一个可以忽略的常数。我们也可以忽略第二个项,因为在最优的 θꭺ 处梯度为零。现在我们已经找到了 log(p(θ|A)) 的表达式,让我们将其放回图 8 中的原始公式中:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/e78051bb5e2190c7e7e26e6e692bb414.png

图 11:具有两个任务的扩展对数最大化目标

其中第二个项的二阶导数是 Hessian,可以根据图 5 中的定义用 Fisher 信息矩阵来近似。log(p(B|θ)) 是新任务 B 的损失,例如交叉熵,我们用 Lᵦ(θ) 表示。

便利的是,我们不需要执行二阶导数,可以根据图 4 中的定义仅使用一阶导数来近似 FIM,这相当于图 5 中的定义,即对数似然梯度的外积:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/ddd17dce774145437c7c960eaa0b94f9.png

图 12:Fisher 定义为梯度之间的向量积

最后,要优化的整体损失 L(θ) 变为:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/a395c9431f928ef467ca6c1d988d86e5.png

图 13:最终损失

其中λ是一个超参数,表示保持先前任务 A 的准确性的重要性。

旁注:上述涉及梯度向量外积的定义捕捉了梯度的协方差结构。然而,通常使用 FIM 的对角线近似,即梯度的平方,它只计算参数的方差,但计算成本较低,对于该任务来说是足够的:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/de63a85c89449489bcac89752510c8c6.png

图 14:Fisher 近似

实现方法

让我们先导入一些库以及代表任务 A 和 B 的 MNIST([商业使用许可](https://github.com/sharmaroshan/MNIST-Dataset/blob/master/LICENSE))和 Fashion MNIST(商业使用许可)数据集。我们还定义了一个简单的神经网络:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
import numpy as np
from torch.utils.data import DataLoader

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

def get_accuracy(model, dataloader):
    model = model.eval()
    acc = 0
    for input, target in dataloader:
        o = model(input.to(device))
        acc += (o.argmax(dim=1).long() == target.to(device)).float().mean()
    return acc / len(dataloader)

class LinearLayer(nn.Module):
    # from https://github.com/shivamsaboo17/Overcoming-Catastrophic-forgetting-in-Neural-Networks/blob/master/elastic_weight_consolidation.py
    def __init__(self, input_dim, output_dim, act='relu', use_bn=False):
        super(LinearLayer, self).__init__()
        self.use_bn = use_bn
        self.lin = nn.Linear(input_dim, output_dim)
        self.act = nn.ReLU() if act == 'relu' else act
        if use_bn:
            self.bn = nn.BatchNorm1d(output_dim)
    def forward(self, x):
        if self.use_bn:
            return self.bn(self.act(self.lin(x)))
        return self.act(self.lin(x))

class Flatten(nn.Module):

    def forward(self, x):
        return x.view(x.shape[0], -1)

class Model(nn.Module):

    def __init__(self, num_inputs, num_hidden, num_outputs):
        super(Model, self).__init__()
        self.f1 = Flatten()
        self.lin1 = LinearLayer(num_inputs, num_hidden, use_bn=True)
        self.lin2 = LinearLayer(num_hidden, num_hidden, use_bn=True)
        self.lin3 = nn.Linear(num_hidden, num_outputs)

    def forward(self, x):
        return self.lin3(self.lin2(self.lin1(self.f1(x))))

# Load MNIST dataset, representint task A
mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)

# FashiomMNIST is task B
f_mnist_train = datasets.FashionMNIST("../data", train=True, download=True, transform=transforms.ToTensor())
f_mnist_test = datasets.FashionMNIST("../data", train=False, download=True, transform=transforms.ToTensor())
f_train_loader = DataLoader(f_mnist_train, batch_size = 100, shuffle=True)
f_test_loader = DataLoader(f_mnist_test, batch_size = 100, shuffle=False)

现在我们将在 MNIST 任务上训练模型:

# parameters
EPOCHS = 4
lr=0.001
weight=100000 
accuracies = {}

device = 'cuda:1'

criterion = nn.CrossEntropyLoss()

# train model on task A 
model = Model(28 * 28, 100, 10).to(device)
optimizer = optim.Adam(model.parameters(), lr)

for _ in range(EPOCHS):
    for input, target in tqdm(train_loader):
        output = model(input.to(device))
        loss = criterion(output, target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

accuracies['mnist_initial'] = get_accuracy(model, test_loader)

现在我们定义了估计 FIM 和 EWC 损失中使用的先前参数的函数:

def ewc_loss(model, weight, estimated_fishers, estimated_means):
    losses = []
    for param_name, param in model.named_parameters():
        estimated_mean = estimated_means[param_name]
        estimated_fisher = estimated_fishers[param_name]
        losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())

    return (weight / 2) * sum(losses)

def estimate_ewc_params(model, train_ds, batch_size=100, num_batch=300, estimate_type='true'):
    estimated_mean = {}

    for param_name, param in model.named_parameters():
        estimated_mean[param_name] = param.data.clone()

    estimated_fisher = {}
    dl = DataLoader(train_ds, batch_size, shuffle=True)

    for n, p in model.named_parameters():
        estimated_fisher[n] = torch.zeros_like(p)

    model.eval()
    for i, (input, target) in enumerate(dl):
        if i > num_batch:
            break
        model.zero_grad()

        output = model(input.to(device))
        # https://www.inference.vc/on-empirical-fisher-information/ - more on this here
        if ESTIMATE_TYPE == 'empirical':
            # empirical
            label = target.to(device)
        else:
            # true estimate
            label = output.max(1)[1]

        loss = F.nll_loss(F.log_softmax(output, dim=1), label)
        loss.backward()

        # accumulate all the gradients
        for n, p in model.named_parameters():
            estimated_fisher[n].data += p.grad.data ** 2 / len(dl)

    estimated_fisher = {n: p for n, p in estimated_fisher.items()}
    return estimated_mean, estimated_fisher

最后,让我们继续使用 EWC 损失在任务 B 上训练网络:

# compute fisher and mean parameters for EWC loss
estimated_mean, estimated_fisher = estimate_ewc_params(model, mnist_train)

# Train task B fashion mnist
for _ in range(EPOCHS):
    for input, target in tqdm(f_train_loader):
        output = model(input.to(device))
        loss = ewc_loss(model, weight, estimated_fisher, estimated_mean) + criterion(output, target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

accuracies['mnist_EWC'] = get_accuracy(model, test_loader)
accuracies['f_mnist_EWC'] = get_accuracy(model, f_test_loader)

我们得到了以下准确率:

{'mnist_initial': tensor(0.9772, device='cuda:1'),
 'mnist_AB': tensor(0.9717, device='cuda:1'),
 'f_mnist': tensor(0.8312, device='cuda:1')}

我们可以将这些结果与没有 EWC 损失的模型运行结果进行比较:

{'mnist_initial': tensor(0.9762, device='cuda:1'),
 'mnist_AB': tensor(0.1769, device='cuda:1'),
 'f_mnist': tensor(0.8672, device='cuda:1')}

比较这两个结果,我们可以看到 EWC 损失有助于在执行任务 B 的同时,几乎保持任务 A 的准确率不变,并且几乎达到了没有 EWC 损失时的相同准确率水平。

结论

这篇文章中,我们看到了一种允许神经网络在继续学习新任务的同时保留之前学习到的知识的技术。还有其他方法可以实现模型的持续学习,例如基于重放的那些方法,它们存储先前数据的一个子集,并在学习新任务时重新播放它以防止遗忘,元学习方法,参数隔离方法,将模型参数的不同子集分配给不同的任务以避免干扰,以及其他方法。感兴趣的读者可以参考这篇综述论文以获得更广泛的讨论。

参考文献

1612.00796 (arxiv.org) www.youtube.com/watch?v=82molmnRCg0 abhishekaich27.github.io/data/WriteUps/EWC_nuts_and_bolts.pdf github.com/shivamsaboo17/Overcoming-Catastrophic-forgetting-in-Neural-Networks/blob/master/elastic_weight_consolidation.py

### 持续学习的概念及其在机器学习和软件开发中的应用 #### 什么是持续学习持续学习是一种让模型能够随着时间推移不断适应新数据的能力,而不会忘记之前学到的知识。这种能力对于动态环境下的应用程序尤为重要,在这些环境中,数据分布可能会发生变化[^1]。 #### 持续学习的方法 为了实现持续学习,研究者们提出了多种略和技术,主要包括以下几种: 1. **正则化方法** 正则化方法通过惩罚参数更新的方式防止模型遗忘旧知识。例如,“弹性权重巩固”(Elastic Weight Consolidation, EWC)通过对重要参数施加约束来保护先前任务的学习成果[^2]。 2. **重放机制** 这种方法的核心思想是从过去的数据中采样并将其与当前的新数据混合训练。这种方法可以分为两种形式:实际经验回放(Experience Replay),即存储一部分历史数据;以及生成对抗网络辅助的经验回放(Generative Replays),其中使用GAN生成模拟的历史样本[^3]。 3. **模块化架构** 使用独立子网络分别处理不同任务,并通过共享部分结构保持泛化性能。这样设计的好处在于新增任务时只需扩展特定组件而不影响已有功能[^4]。 #### 应用于机器学习领域 在网络安全场景下,持续学习可以帮助构建更加鲁棒的威胁检测系统。随着新型攻击手段层出不穷,传统静态模型难以应对快速变化的风险形势。采用支持增量式学习框架,则可使防护措施紧跟最新趋势,从而更有效地抵御未知类型的入侵行为[^5]。 另外,在视频监控方面,当摄像头覆盖范围扩大或者光照条件改变等因素引起图像特征漂移时,具备在线调整能力的算法显得尤为必要。借助于上述提到的技术路径之一——渐进微调(Progressive Fine-tuning),可以在不重新初始化整个神经网络的前提下完成迁移适配过程[^6]。 #### 软件工程实践中的体现 从软件开发生命周期角度来看,敏捷迭代模式本身就蕴含着某种意义上的“持续改进”。开发者团队基于用户反馈循环优化产品特性直至满足需求为止的过程与此处探讨的主题异曲同工。具体而言: - 自动测试套件配合CI/CD流水线确保每次提交代码改动之后仍能维持既有质量标准; - A/B 测试允许企业在真实生产环境下评估候选版本表现差异进而做出科学决。 ```python def simulate_continuous_learning(model, new_data): """ Simulate a simple form of continuous learning by fine-tuning an existing model with new data. Args: model (object): Pre-trained machine learning model instance. new_data (list or array-like): New dataset used to update the knowledge base. Returns: object: Updated version after incorporating additional information. """ updated_model = model.fit(new_data) # Assuming 'fit' method exists and supports incremental updates return updated_model ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值