门控循环单元(GRU)– 改进 RNN

原文:towardsdatascience.com/gated-recurrent-units-gru-improving-rnns-4467b66c5150

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/a467db9f3904a9b710f80476eecd982d.png

www.flaticon.com/free-icons/neural-network” title=”neural network icons”>由 juicy_fish 创建的神经网络图标 – Flaticon。


在这篇文章中,我将探讨循环神经网络(RNNs)的标准实现:门控循环单元(GRUs)。

GRU 由 Kyunghyun Cho 等人于 2014 年提出,是 vanilla RNN 的改进,因为它们较少受到梯度消失问题的影响,这使得它们能够拥有更长的记忆。

它们类似于长短期记忆(LSTM)网络,但操作更少,这使得它们更节省内存。

我们将涵盖的内容:

  • RNN 概述

  • 消失和爆炸梯度问题

  • LSTM 概述

  • GRU 的工作原理

循环神经网络

循环神经网络是一种特别擅长处理基于序列的数据的神经网络,例如自然语言和时间序列。

它们通过添加一个“循环”神经元来实现这一点,允许信息从前一个输入和输出传递到下一步。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f78144a8032fae629649666afeae0bea.png

显示隐藏状态的 RNN 的示例架构。图由作者绘制。

左边是一个 RNN,右边是相同但随时间展开的 RNN。

注意 _ 前一步的执行是如何传递到下一步的。这为系统添加了“记忆”,有助于模型识别历史模式。这种记忆是通过传递隐藏状态, h_ *来实现的。

例如,对于 _Y_1*,循环神经元使用* X_1 的输入和前一个时间步的输出, Y_0*。这意味着* Y_0 间接影响了 Y_2

对于循环神经元的每个隐藏状态都可以计算如下:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/186edfa18c847b8976746406aca4d854.png

循环神经元隐藏状态方程。由作者在 LaTeX 中创建。

其中:

  • h_t​ 是时间 t 的隐藏状态。

  • h{t−1​}_ 是前一个时间步的隐藏状态。

  • x_t​ 是时间 t 的输入数据。

  • _W_h_​ 是隐藏状态的权重矩阵。

  • _W_x *是输入数据的权重矩阵。

  • b_h 是隐藏状态的偏置向量。

  • σ激活函数,通常是 tanh 或 sigmoid。

上述方程是常规标准循环单元,但当我们看到 LSTMs 和 GRUs 时,事情会变得稍微复杂一些。

如果您想深入了解 RNN,并查看背后的手写示例,请查看我的上一篇文章。

循环神经网络 – 序列建模简介

梯度爆炸与消失问题

时间反向传播

RNN 通过一个称为[反向传播通过时间(BPTT)](https://medium.com/towards-data-science/backpropagation-through-time-how-rnns-learn-e5bc03ad1f0a?sk=77ae1815bf6a5066c1e4790128ebf9f2)的算法进行学习,这是用于训练常规神经网络的常规反向传播算法的一个变体。

反向传播的目标是计算损失函数作为每个参数的偏导数,然后使用梯度下降更新其值。

以这个 RNN 为例:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/c9013ece0944c463c17ba6ae40ee58f8.png

带有关联权重矩阵的 RNN。由作者绘制的图。

这里我们有:

  • Y 是输出向量

  • X 是特征的输入向量

  • h 是隐藏状态

  • V 是输出的加权矩阵

  • U 是输入的加权矩阵

  • W 是隐藏状态的加权矩阵

在给定的时间 t,计算出的输出是:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/72850d43f21d6b5a757a4313fc2704a8.png

通用 RNN 方程。由作者使用 LaTeX 生成。

其中 σ 是一个 激活函数,通常是 sigmoid

假设我们的损失函数是均方误差:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/bea866e869eb4a3ac26fa179f7b09bd0.png

RNN 的损失函数。由作者使用 LaTeX 生成。

其中 A_t 是实际值。

我们需要计算以下内容来使用第 3 个时间步长的误差更新矩阵 W。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/a5236f9f3f94e461499061d58225249e.png

由作者使用 LaTeX 生成的方程。

表达式中的第一项相对简单 。 E_3Y_3** 的函数,Y_3* 是 h_3** 的函数,*h**3 是 W_ 中的一个元素。

然而,我们不能就此止步。矩阵 W 也用于之前的步骤来计算 h_2h_1,因此我们还需要对这些较早的时间步长进行微分。

通常,我们必须在状态 h_3 之前考虑所有时间步的 W 的影响。然而,这引入了一个问题,因为我们现在有一个很长的计算链来传递我们的梯度。

如果你想了解时间反向传播的完整解析,请查看我的上一篇文章。

时间反向传播 – RNN 如何学习

梯度问题

如果我们有很多时间步,求和项的数量将会增加。当这种情况发生时,对于早期的时间步,传播的梯度会逐渐消失。这是因为我们多次乘以一个介于 0 和 1 之间的数。

在 BPTT 的情况下,我们重复计算隐藏状态相对于前一步隐藏状态的偏导数,就像我们上面展示的那样。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/c1eec3b28b3b8103a04861db3a2f929c.png

作者用 LaTeX 生成的方程式。

所以发生的情况是:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/1b943604cf42c3754e4b3df937978c05.png

作者用 LaTeX 生成的方程式。

你现在可以看到为什么 RNN 会经历梯度消失和爆炸。如上所示,隐藏状态的偏导数乘积的链式法则效应导致梯度相对于时间步的数量呈指数级消失或爆炸。

那么我们该怎么办呢?

LSTM

概述

长短期记忆(LSTM)网络通过引入称为门的元素来克服梯度爆炸和消失的问题,这些门控制着数据在循环细胞中的流入和流出。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/9109aa0d18e2f05c3565cf05cd7cf3cc.png

LSTM RNN 细胞示例。图由作者绘制。

一个需要提到的关键特性是细胞状态,C. 这个细胞状态包含有关上下文和历史模式的信息,基本上是记忆。它的目的是保存整个网络的记忆,而隐藏状态只是用于输出和短期依赖。

有四个主要部分需要理解。

遗忘门

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/6319254e6d923fb19722e2c859f8aeb4.png

作者用 LaTeX 编写的遗忘门方程。

这个门决定从先前的细胞状态 C{t-1}_ 中移除哪些旧信息。

位置

  • σ: sigmoid 激活函数。

  • _Wf​: 遗忘门的权重矩阵。

  • h{t−1}_:前一个时间步的输出。

  • _Xt​: 时间步 t 的输入。

  • _bf​: 遗忘门的偏置。

  • f_t: 遗忘门输出,其值介于 0 和 1 之间。

  • X_t: 当前输入。

输入门与候选细胞状态

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/bd07d94ce953389092f2ef5a4f0c5222.png

作者用 LaTeX 编写的输入门方程。

接下来是输入门,i_t,它决定了在当前时间戳的细胞状态中要添加哪些新内容。候选细胞状态,* Ct**_,包含了我们可以添加到细胞状态中的所有潜在信息。

其中:

  • W_i:输入门的权重矩阵。

  • W_c:候选细胞状态的权重矩阵。

  • _X*t_:时间步t的输入。

  • h{t−1}_:前一个时间步的输出。

  • _b*i_:输入门的偏置

  • b_c:候选细胞状态的偏置。

  • _Ct**:**候选细胞状态,输出值介于-1 和 1 之间。

  • i_t:输入门输出,介于 0 和 1 之间。

  • σ: sigmoid 激活函数。

  • tanh: tanh 激活函数。

细胞状态更新

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/70c4db13a0a07cdd9e9e9c2a1e7d5e6b.png

作者在 LaTeX 中给出的细胞状态更新方程。

为了只添加候选细胞状态中最重要的信息,我们将候选细胞状态与输入门相乘,并将其加到遗忘门和前一个细胞状态的乘积上。

输出门

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/fec279d7c66efb6127202f027317fb61.png

作者在 LaTeX 中给出的输出门方程。

最后一步是决定从细胞中输出什么。第一部分计算输出门,o_t,它决定了我们将暴露哪些细胞状态的部分。这基本上是我们在之前看到的正常 RNN 中隐藏状态循环单元的计算。

这个输出,o_t,然后乘以应用 tanh 函数到新细胞状态的结果。这确保我们只使用我们想要输出的值进行预测。

其中:

  • h{t−1}_:前一个时间步的输出。

  • _X*t_:时间步t的输入。

  • W_o:输出门的权重矩阵。

  • b_o:偏置。

  • o_t:输出状态。

  • h_t:新的隐藏状态。

  • C_t:新的细胞状态。

  • σ: sigmoid 激活函数。

  • tanh: tanh 激活函数。

概述

LSTMs 使用门来控制网络中的数据流,以确保只有相关信息保留。像大多数网络一样,魔法发生在通过时间反向传播训练各种权重矩阵,以便它们可以学习忘记或删除哪些部分。这种权重的优化是通过时间反向传播完成的。

门控循环单元

概览

我们最终来到了 GRUs,它们与 LSTMs 相似。它们的目的是通过允许新的网络更好地建立更广泛的记忆上下文来实现类似的功能。

它们的门控数比 LSTMs 少;因此,它们略微简化了计算并使其更高效。它们将输入门和遗忘门合并为一个单独的门。

GRUs 有两个门,一个更新门和一个重置门。它们的值通过 BPTT 训练来理解从当前输入和给定历史信息中应该记住或忘记哪些信息。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/10417a969da0ebb151cb70ed3fb4fe21.png

GRU 图,作者绘制。

让我们一步步来分解它。

更新门

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/b1bd8f4eff2f4e05f96aa50190a63d29.png

作者在 LaTeX 中的更新门方程。

更新门负责决定从 h{t-1}_ 中哪些过去的信息对当前计算和未来是有帮助的。在这种情况下,我们有:

  • _zt: 时间步 t 的更新门。

  • 𝜎: sigmoid 激活函数。

  • W_z:权重矩阵。

  • h{t-1}_:前一个时间步的隐藏状态。

  • X_t:新的输入。

  • b_z:偏置项。

注意到这个方程基本上与常规循环单元的方程相同。区别在于权重矩阵学习的内容。对于更新门,它们将学习由于网络设置方式而需要记住的内容。所以,这一切都归结于网络的权重!

重置门

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/e8182e8dbc92c52e823208102fa95c1f.png

作者在 LaTeX 中的重置门方程。

重置门与更新门相反,它决定我们应该忘记多少从 h{t-1}_ 中的过去信息。在这里我们有:

  • _rt: 时间步 t 的重置门。

  • 𝜎: sigmoid 激活函数。

  • W_r:权重矩阵。

  • h{t-1}_:前一个时间步的隐藏状态。

  • X_t:新的输入。

  • b_r:偏置项。

再次强调,这个方程与更新门相同,但权重矩阵的学习会影响这个门在网络中的操作方式。

候选隐藏状态

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/1d650298577a1fe9aa3c26a92206e328.png

作者在 LaTeX 中的候选隐藏状态方程。

这部分计算我们可以传递给新隐藏状态的可能的新信息。它是通过构建一个中间值,称为候选隐藏状态,_ht***,来做到这一点的。

重置门,r_t*,和前一个隐藏状态,h_{t-1}_,之间的哈达玛积决定了从先前知识中应该去除什么。

最终隐藏状态输出

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/8364942cba5231f0246bb918aad75ae1.png

作者在 LaTeX 中的最终隐藏状态方程。

最后的部分是计算最终的隐藏状态,h_t。这是通过使用更新门,z_t*,来确定我们应该从候选隐藏状态,ht***,中获取哪些新信息,以及我们应该保留哪些来自先前信息,h{t-1}_。

如果 z_t 为 0,我们忽略所有新信息,新的隐藏状态就等于旧的。然而,如果 z_t_ 为 1,那么我们就从新的候选隐藏状态中获取所有信息。

它将介于新旧隐藏状态之间,并组合新旧隐藏状态。

优势

GRU 相比 LSTM 具有以下优势

  • 结构更简单。

  • 更容易训练。

  • 门的数量和参数更少。

然而,它们的性能往往比 LSTM 差,因为它们难以保留长期依赖关系。

摘要及进一步思考

本文讨论了门控循环单元(GRUs)以及它们如何克服常规 RNN 所遭受的梯度消失问题。它们通过分配门控,使网络能够记住过去和当前时间步的关键信息,从而跳过不必要的连接。它们与 LSTMs 非常相似,但参数更少,这使得它们在计算上更高效。

另一件事!

我提供 1:1 辅导通话,我们可以讨论你需要的一切——无论是项目、职业建议,还是只是确定你的下一步。我在这里帮助你前进!

[1:1 Mentoring Call with Egor Howell]

职业指导、求职建议、项目帮助、简历审查topmate.io](https://topmate.io/egorhowell/1203300)

与我联系

参考文献

### 原理 门控循环单元GRU)是一种用于处理序列数据的递归神经网络(RNN)变体,通过引入门控机制解决传统RNN在处理长序列时的梯度消失问题。其核心思想是在保留LSTM长程记忆能力的前提下,简化网络结构。具体做法是合并LSTM的输入门和遗忘门为更新门,去除细胞状态,直接通过隐藏状态传递信息。 GRU的核心组件包括两个门和候选状态: - **更新门**($z_t$):控制历史信息与当前信息的融合比例,公式为$z_t = \sigma(W_z \cdot [h_{t - 1}, x_t] + b_z)$,其中$\sigma$是Sigmoid函数,输出0 - 1间的保留比例[^2]。 - **重置门**($r_t$):决定忽略多少历史信息来生成候选状态,公式为$r_t = \sigma(W_r \cdot [h_{t - 1}, x_t] + b_r)$ [^2]。 - **候选隐藏状态**($\tilde{h}_t$):包含当前输入与部分历史信息的中间状态,公式为$\tilde{h}_t = \tanh(W \cdot [r_t \odot h_{t - 1}, x_t] + b)$,其中$\odot$是Hadamard积,用于控制历史信息流入量 [^2]。 最终隐藏状态$h_t$的计算公式为$h_t = (1 - z_t) \odot h_{t - 1} + z_t \odot \tilde{h}_t$ [^2]。 ### 应用 - **自然语言处理**:在机器翻译、文本生成、情感分析等任务中,GRU可以处理输入的序列信息,学习到文本中的语义和语法结构,从而进行准确的翻译、生成合理的文本或判断情感倾向。 - **语音识别**:处理语音信号的时序特征,将语音转换为文本。 - **时间序列预测**:如股票价格预测、天气预测等,通过学习时间序列数据的模式和趋势,进行未来值的预测。 ### 特点 - **结构简化**:合并了LSTM的输入门和遗忘门,去除了细胞状态,直接通过隐藏状态传递信息,使得网络结构更加简洁 [^2]。 - **参数减少**:参数数量比LSTM减少1/3,训练速度提升20 - 30%,能够在一定程度上提高训练效率 [^2]。 - **处理长时依赖**:和LSTM一样,GRU能够有效处理序列数据中的长时依赖问题,避免了传统RNN的梯度消失问题 [^1]。 以下是一个使用PyTorch实现简单GRU模型的代码示例: ```python import torch import torch.nn as nn class SimpleGRU(nn.Module): def __init__(self, input_size, hidden_size, num_layers, output_size): super(SimpleGRU, self).__init__() self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True) self.fc = nn.Linear(hidden_size, output_size) def forward(self, x): out, _ = self.gru(x) out = self.fc(out[:, -1, :]) return out # 示例参数 input_size = 10 hidden_size = 20 num_layers = 1 output_size = 1 # 创建模型实例 model = SimpleGRU(input_size, hidden_size, num_layers, output_size) # 示例输入 batch_size = 32 seq_length = 5 input_tensor = torch.randn(batch_size, seq_length, input_size) # 前向传播 output = model(input_tensor) print(output.shape) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值