使用代码解释GRU的原理

背景

        搜索GRU的结构图,现在基本都是这张图,理解起来的确有点吃力,还是得用代码跑一遍。

简单的GRUCell

pytorch中的GRUCell的计算公式是这样的,我们先不用管公式,跑个简单的demo看看


 

简单DEMO

通过以下代码我们可以看到上次的隐藏状态和当前的隐藏状态

input_size = 1
hidden_size = 1
batch_size = 1
gru = nn.GRUCell(input_size, hidden_size)
input = torch.randn(6, batch_size, input_size)
hx = torch.randn(batch_size, hidden_size)
hx, gru(input[0], hx)

提取模型参数

weight_ih, weight_hh, bias_ih, bias_hh = list(gru.parameters())
weight_ih, weight_hh, bias_ih, bias_hh

使用函数模拟

r - 重置门
                                r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr})
  • 输出范围为[0,1], 基本计算公式和RNNCell一样
weight_ir_l0 = weight_ih[0]
bias_ir_l0 = bias_ih[0]
weight_hr_l0 = weight_hh[0]
bias_hr_l0 = bias_hh[0]
r = torch.sigmoid(weight_ir_l0 * input[0] + bias_ir_l0  + weight_hr_l0 * hx + bias_hr_l0)
r

z-更新门

                                z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz})

  • 输出范围为[0,1], 和r重置门的计算方式一样,只不过用途不同
weight_iz_l0 = weight_ih[1]
bias_iz_l0 = bias_ih[1]
weight_hz_l0 = weight_hh[1]
bias_hz_l0 = bias_hh[1]
z = torch.sigmoid(weight_iz_l0 * input[0] + bias_iz_l0  + weight_hz_l0 * hx + bias_hz_l0)
z

n-候选隐藏状态,即当前的隐藏状态

        ​​​​​​​        ​​​​​​​        n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn}))

weight_in_l0 = weight_ih[2]
bias_in_l0 = bias_ih[2]
weight_hn_l0 = weight_hh[2]
bias_hn_l0 = bias_hh[2]
n = torch.tanh(weight_in_l0 * input[0] + bias_in_l0  + torch.matmul(r, (weight_hn_l0 * hx + bias_hn_l0)))
n

h‘-计算,最终的隐藏状态

        ​​​​​​​        ​​​​​​​        ​​​​​​​        h' = (1 - z) \odot n + z \odot h


h1 = torch.matmul(1 - z, n) + torch.matmul(z, hx)
hx, h1

分析

        不出意外,我们得到的记过和h1和上面GRUCell的运行结果是一样的。接下来我们分析下为什么要这么做?

        从h‘的公式可以看出,(1-z)、z为[0,1]的权重系数,其大小决定最终状态h‘的组成占比。

  • h 我们理解为长期记忆(也就是上次的隐藏状态), n时是通过RNNCell方式计算出来的当前的隐藏状态, 也就是短期依赖,因此通过z(更新门)就控制了长短期依赖的比重。
  • 不知道有没有细心的同学发现,z计算的时候用到了重置门r, r 控制了上次隐藏状态h的比例,那是不是和最终隐藏状态h'的计算有点冲突了,两个公式里面都有控制h的权重?

        我刚开始看的时候也有这个疑问,后来去原论文中看了看 有这样一段解释

In this formulation, when the reset gate is close to 0, the hidden state is forced to ignore the pre- vious hidden state and reset with the current input

only. This effectively allows the hidden state to drop any information that is found to be irrelevant later in the future, thus, allowing a more compact representation.

On the other hand, the update gate controls how much information from the previous hidden state will carry over to the current hidden state. This acts similarly to the memory cell in the LSTM network and helps the RNN to remember long- term information. Furthermore, this may be con- sidered an adaptive variant of a leaky-integration unit (Bengio et al., 2013).

As each hidden unit has separate reset and up- date gates, each hidden unit will learn to capture dependencies over different time scales. Those units that learn to capture short-term dependencies will tend to have reset gates that are frequently ac- tive, but those that capture longer-term dependen- cies will have update gates that are mostly active.

In our preliminary experiments, we found that it is crucial to use this new unit with gating units. We were not able to get meaningful result with an oft-used tanh unit without any gating.

总结下来就是r和z是两个独立的单元,独立学习,r控制是否丢掉长期的依赖信息,只保留短期依赖,z控制保留长期依赖的比例,两者并不冲突。就类似于 y = wx+b 本身是有w和b控制,只有一个神经元,我们将其优化为y = (wx1 + b1)x2 + b, 其中w的成分变得更复杂了(变成两个神经元),这样会带来更多的灵活性,不知道这样理解对不对?

网络训练是黑盒,我们可能会碰到需要部分神经元丢掉部分不相关的长期依赖,保留当前输入的相关的短期依赖,同时另一个神经元需要保留相关的长期依赖,两者都会对最终的结果最初贡献。

总结

本文主要用简单方便理解的数学公式复现了GRU黑盒,并借此来加深理解,欢迎提出建议。

### GRU神经网络原理图及结构详解 #### 什么是GRU? 门控循环单元(Gated Recurrent Unit, GRU)是一种用于处理序列数据的循环神经网络(Recurrent Neural Network, RNN),其设计旨在解决传统RNN中的梯度消失问题以及长期依赖问题。通过引入更新门(Update Gate)和重置门(Reset Gate),GRU能够更有效地控制信息流,从而更好地捕捉时间序列中的复杂模式[^3]。 --- #### GRU的核心组件 GRU的主要组成部分包括以下几个方面: 1. **更新门(Update Gate)** 更新门决定了前一时刻的状态有多少会被保留到当前状态中。它的计算方式如下: \[ z_t = \sigma(W_z \cdot [h_{t-1}, x_t]) \] 其中 \(z_t\) 是更新门的输出,\(W_z\) 是权重矩阵,\(\sigma\) 表示sigmoid激活函数,\[h_{t-1}\] 和 \(x_t\) 分别表示上一时刻隐藏状态和当前输入。 2. **重置门(Reset Gate)** 重置门的作用是决定是否忽略之前的隐藏状态,在某些情况下可以减少不必要的历史信息干扰。其公式为: \[ r_t = \sigma(W_r \cdot [h_{t-1}, x_t]) \] 3. **候选隐藏状态(Candidate Hidden State)** 候选隐藏状态是一个临时变量,用于存储潜在的新状态值。它基于当前输入和经过重置门调整后的先前状态来计算: \[ \tilde{h}_t = \text{tanh}(W_h \cdot [r_t * h_{t-1}, x_t]) \] 4. **最终隐藏状态(Final Hidden State)** 隐藏状态由更新门调节旧状态与新状态之间的比例得到: \[ h_t = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_t \] 以上四个部分共同构成了GRU的工作流程。 --- #### GRU原理图解 以下是GRU的典型结构图及其各模块的功能描述: ![GRU Structure](https://upload.wikimedia.org/wikipedia/commons/thumb/8/8f/Gated_Recurrent_Unit.svg/1920px-Gated_Recurrent_Unit.svg.png) - 图中标注了三个主要操作区域: - 左侧展示了如何利用更新门 (\(z_t\)) 来融合过去的状态; - 中间显示了重置门 (\(r_t\)) 对于遗忘机制的影响; - 右边则呈现了新的候选状态是如何生成并与现有状态相结合的过程。 这种紧凑而高效的架构使得GRU相较于传统的LSTM更加轻量化,同时也保持了较强的表达能力[^1]。 --- #### Python实现代码示例 下面提供了一个简单的GRU层定义代码片段供参考: ```python import torch.nn as nn class SimpleGRUNet(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(SimpleGRUNet, self).__init__() self.hidden_size = hidden_size self.gru = nn.GRU(input_size, hidden_size) self.fc = nn.Linear(hidden_size, output_size) def forward(self, x): out, _ = self.gru(x) out = self.fc(out[-1]) # 使用最后一个时间步的结果 return out ``` 此代码创建了一种基本形式的GRU模型框架,适用于多种序列预测任务场景下快速搭建原型测试环境[^2]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值