简介
训练人工神经网络时面临的最具挑战性的问题之一是灾难性遗忘。当神经网络在一个任务(任务 A)上训练后,随后学习一个新的任务(任务 B),在这个过程中,忘记了如何执行原始任务时,就会出现这个问题。在本文中,我们将探讨一种解决此问题的方法,称为弹性权重巩固(EWC)。EWC 提供了一种有前景的方法来减轻灾难性遗忘,使神经网络能够在学习新技能的同时保留先前学习任务的知识。
本文中所有图表均由作者提供,除非另有说明
直觉
图 1:EWC 的直觉,来自论文
已经证明,存在许多配置的最佳参数,在任务上具有所需的低误差 – 上图中任务 A 和 B 的灰色和黄色区域分别表示。假设我们找到了这样一个配置 θꭺ 用于任务 A,当继续从这种配置训练模型到新任务 B 时,我们面临三种不同的场景:
-
简单地继续在任务 B 上训练而不施加惩罚,最终会导致任务 B 的低水平区域,但在任务 A 上的表现低于期望的准确性 – 蓝色箭头。
-
使用任务 A 的权重 L2 约束可能过于强烈,使得模型在任务 A 上表现良好,但在任务 B 上欠拟合 – 绿色箭头。
-
EWC 是提出的解决方案,将在模型在两个任务上都表现良好的区域找到参数(两个区域的交集) – 红色箭头。
接下来,我们将解释这是如何完成的。
理论
鱼际信息矩阵(FIM)
让我们先了解 EWC 方法所基于的 FIM(Fisher 信息矩阵)。FIM 是一种统计量度,它量化了给定数据为我们提供关于我们旨在估计的未知参数 θ 的信息量。在持续学习的背景下,FIM 将有助于识别神经网络参数,这些参数对先前任务的数据提供的信息较少。通过更新这些参数,网络可以在不删除存储在参数中的重要信息的情况下学习新任务,这些参数对先前学习任务的信息非常丰富。
为了更加正式,假设 X 是一个随机变量,其概率密度函数由 θ 参数化,表示为 f(x|θ)。对于样本 x 的似然函数,它仅是参数的函数,数据保持固定,可以表示为:
图 2:似然函数 FIM
和对数似然:
图 3:对数似然函数 FIM
我们现在可以定义 FIM 为:
图 4:FIM 的定义(a)
这表明对数似然函数对参数的小幅度变化有多敏感。一个等效的定义,稍后将会用到,是我们可以将 FIM 看作似然函数二阶导数的负期望:
图 5:FIM 的定义(b)
在求二阶导数时,我们基本上是在观察似然函数的曲率。
为了说明这一点,考虑下面的图表,其中绘制了两个似然函数。蓝色曲线代表一个在其峰值周围非常狭窄的分布,表明我们的数据在θ附近的可能性很大,并且当我们远离它时迅速下降。相比之下,黑色曲线代表一个更宽泛的分布,其中数据即使在远离θ时也保持相似的可能性。FIM 量化了这个概念——我们的数据如何紧密地约束在θ的某个特定值上。大的 FIM(如蓝色曲线所示)意味着参数值的小幅变化将导致在这些参数下数据似然度的显著下降。相反,小的 FIM(如黑色曲线所示)意味着参数值的小幅变化将导致似然度的小幅降低。
图 6:FIM 的视觉说明
结果表明,Fisher 信息矩阵与数据的方差(或在多元情况下的协方差)成反比。在上面的图表中,如果我们假设曲线代表两个具有均值 _θ_₀和方差σ²ᵦₗᵤₑ和σ²ᵦₗₐ꜀ₖ的 Gaussian 分布,其中σ²ᵦₗᵤₑ < σ²ᵦₗₐ꜀ₖ,FIM 等于 1/σ²,因此蓝色曲线包含更多的信息。
弹性权重巩固
给定数据D和一个具有参数θ的神经网络,我们的目标是最大化给定数据的参数概率p(θ|D)。根据贝叶斯定理,我们有:
图 7:最大化目标
将对数应用于等式的两边不会改变最大化目标,因为对数是一个单调变换。因此,我们的目标变为:
图 8:对数最大化目标
假设有两个独立任务D = {A, B},我们有:
图 9:两个任务的对数最大化目标
最后一个结果来自 A 和 B 的独立性。在这里,log(p(B|θ)) 是任务 B 的损失,log(p(B)) 是 B 的似然,对于我们的优化来说可以被视为常数,因为它不依赖于 θ,而 log(p(θ|A)) 是任务 A 的后验分布,包含有关哪些参数对任务 A 重要的所有信息。
估计 log(p(θ|A)) 是不可行的,因为计算它将涉及到在整个参数空间上对高维函数进行积分。因此,它被一个具有均值 A 任务最优参数 – θꭺ 和方差 Fisher 信息矩阵的正态分布所近似。这种近似是有意义的,因为我们假设 A 和 B 任务的新参数 θ 将不会远离 A 任务的最优参数。此外,在 θꭺ 的所有参数中,将有一些参数对于 A 任务的良好性能更为重要,我们不想让它们变化太大,这就是 FIM 发挥作用的地方,其值,如前所述,表示改变某个参数将如何影响此案例中 A 任务的损失,因此具有更高 FIM 值的参数变化会受到更多惩罚。
现在,让我们对任务 A 的最优权重进行到二阶项的泰勒展开:
图 10:围绕 A 任务最优参数的泰勒展开
在优化过程中,log(p(θꭺ|A))* 是一个可以忽略的常数。我们也可以忽略第二个项,因为在最优的 θꭺ 处梯度为零。现在我们已经找到了 log(p(θ|A)) 的表达式,让我们将其放回图 8 中的原始公式中:
图 11:具有两个任务的扩展对数最大化目标
其中第二个项的二阶导数是 Hessian,可以根据图 5 中的定义用 Fisher 信息矩阵来近似。log(p(B|θ)) 是新任务 B 的损失,例如交叉熵,我们用 Lᵦ(θ) 表示。
便利的是,我们不需要执行二阶导数,可以根据图 4 中的定义仅使用一阶导数来近似 FIM,这相当于图 5 中的定义,即对数似然梯度的外积:
图 12:Fisher 定义为梯度之间的向量积
最后,要优化的整体损失 L(θ) 变为:
图 13:最终损失
其中λ是一个超参数,表示保持先前任务 A 的准确性的重要性。
旁注:上述涉及梯度向量外积的定义捕捉了梯度的协方差结构。然而,通常使用 FIM 的对角线近似,即梯度的平方,它只计算参数的方差,但计算成本较低,对于该任务来说是足够的:
图 14:Fisher 近似
实现方法
让我们先导入一些库以及代表任务 A 和 B 的 MNIST([商业使用许可](https://github.com/sharmaroshan/MNIST-Dataset/blob/master/LICENSE))和 Fashion MNIST(商业使用许可)数据集。我们还定义了一个简单的神经网络:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
def get_accuracy(model, dataloader):
model = model.eval()
acc = 0
for input, target in dataloader:
o = model(input.to(device))
acc += (o.argmax(dim=1).long() == target.to(device)).float().mean()
return acc / len(dataloader)
class LinearLayer(nn.Module):
# from https://github.com/shivamsaboo17/Overcoming-Catastrophic-forgetting-in-Neural-Networks/blob/master/elastic_weight_consolidation.py
def __init__(self, input_dim, output_dim, act='relu', use_bn=False):
super(LinearLayer, self).__init__()
self.use_bn = use_bn
self.lin = nn.Linear(input_dim, output_dim)
self.act = nn.ReLU() if act == 'relu' else act
if use_bn:
self.bn = nn.BatchNorm1d(output_dim)
def forward(self, x):
if self.use_bn:
return self.bn(self.act(self.lin(x)))
return self.act(self.lin(x))
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.shape[0], -1)
class Model(nn.Module):
def __init__(self, num_inputs, num_hidden, num_outputs):
super(Model, self).__init__()
self.f1 = Flatten()
self.lin1 = LinearLayer(num_inputs, num_hidden, use_bn=True)
self.lin2 = LinearLayer(num_hidden, num_hidden, use_bn=True)
self.lin3 = nn.Linear(num_hidden, num_outputs)
def forward(self, x):
return self.lin3(self.lin2(self.lin1(self.f1(x))))
# Load MNIST dataset, representint task A
mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)
# FashiomMNIST is task B
f_mnist_train = datasets.FashionMNIST("../data", train=True, download=True, transform=transforms.ToTensor())
f_mnist_test = datasets.FashionMNIST("../data", train=False, download=True, transform=transforms.ToTensor())
f_train_loader = DataLoader(f_mnist_train, batch_size = 100, shuffle=True)
f_test_loader = DataLoader(f_mnist_test, batch_size = 100, shuffle=False)
现在我们将在 MNIST 任务上训练模型:
# parameters
EPOCHS = 4
lr=0.001
weight=100000
accuracies = {}
device = 'cuda:1'
criterion = nn.CrossEntropyLoss()
# train model on task A
model = Model(28 * 28, 100, 10).to(device)
optimizer = optim.Adam(model.parameters(), lr)
for _ in range(EPOCHS):
for input, target in tqdm(train_loader):
output = model(input.to(device))
loss = criterion(output, target.to(device))
optimizer.zero_grad()
loss.backward()
optimizer.step()
accuracies['mnist_initial'] = get_accuracy(model, test_loader)
现在我们定义了估计 FIM 和 EWC 损失中使用的先前参数的函数:
def ewc_loss(model, weight, estimated_fishers, estimated_means):
losses = []
for param_name, param in model.named_parameters():
estimated_mean = estimated_means[param_name]
estimated_fisher = estimated_fishers[param_name]
losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())
return (weight / 2) * sum(losses)
def estimate_ewc_params(model, train_ds, batch_size=100, num_batch=300, estimate_type='true'):
estimated_mean = {}
for param_name, param in model.named_parameters():
estimated_mean[param_name] = param.data.clone()
estimated_fisher = {}
dl = DataLoader(train_ds, batch_size, shuffle=True)
for n, p in model.named_parameters():
estimated_fisher[n] = torch.zeros_like(p)
model.eval()
for i, (input, target) in enumerate(dl):
if i > num_batch:
break
model.zero_grad()
output = model(input.to(device))
# https://www.inference.vc/on-empirical-fisher-information/ - more on this here
if ESTIMATE_TYPE == 'empirical':
# empirical
label = target.to(device)
else:
# true estimate
label = output.max(1)[1]
loss = F.nll_loss(F.log_softmax(output, dim=1), label)
loss.backward()
# accumulate all the gradients
for n, p in model.named_parameters():
estimated_fisher[n].data += p.grad.data ** 2 / len(dl)
estimated_fisher = {n: p for n, p in estimated_fisher.items()}
return estimated_mean, estimated_fisher
最后,让我们继续使用 EWC 损失在任务 B 上训练网络:
# compute fisher and mean parameters for EWC loss
estimated_mean, estimated_fisher = estimate_ewc_params(model, mnist_train)
# Train task B fashion mnist
for _ in range(EPOCHS):
for input, target in tqdm(f_train_loader):
output = model(input.to(device))
loss = ewc_loss(model, weight, estimated_fisher, estimated_mean) + criterion(output, target.to(device))
optimizer.zero_grad()
loss.backward()
optimizer.step()
accuracies['mnist_EWC'] = get_accuracy(model, test_loader)
accuracies['f_mnist_EWC'] = get_accuracy(model, f_test_loader)
我们得到了以下准确率:
{'mnist_initial': tensor(0.9772, device='cuda:1'),
'mnist_AB': tensor(0.9717, device='cuda:1'),
'f_mnist': tensor(0.8312, device='cuda:1')}
我们可以将这些结果与没有 EWC 损失的模型运行结果进行比较:
{'mnist_initial': tensor(0.9762, device='cuda:1'),
'mnist_AB': tensor(0.1769, device='cuda:1'),
'f_mnist': tensor(0.8672, device='cuda:1')}
比较这两个结果,我们可以看到 EWC 损失有助于在执行任务 B 的同时,几乎保持任务 A 的准确率不变,并且几乎达到了没有 EWC 损失时的相同准确率水平。
结论
在这篇文章中,我们看到了一种允许神经网络在继续学习新任务的同时保留之前学习到的知识的技术。还有其他方法可以实现模型的持续学习,例如基于重放的那些方法,它们存储先前数据的一个子集,并在学习新任务时重新播放它以防止遗忘,元学习方法,参数隔离方法,将模型参数的不同子集分配给不同的任务以避免干扰,以及其他方法。感兴趣的读者可以参考这篇综述论文以获得更广泛的讨论。
参考文献
1612.00796 (arxiv.org) www.youtube.com/watch?v=82molmnRCg0 abhishekaich27.github.io/data/WriteUps/EWC_nuts_and_bolts.pdf github.com/shivamsaboo17/Overcoming-Catastrophic-forgetting-in-Neural-Networks/blob/master/elastic_weight_consolidation.py
296

被折叠的 条评论
为什么被折叠?



