GRU (Gated Recurrent Unit)也称门控循环单元结构,它也是传统RNN的变体,同LSTM一样能
够有效捕捉长序列之间的语义关联缓解、梯度消失或爆炸现象。同时它的结构和计算要比LSTM
更简单它的核心结构可以分为两个部分去解析:更新门和重置门。使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。
一、总体结构
如图是一个GRU单元的内部构造,图中的和
分别表示更新门和重置门。GRU将LSTM中的输入门和遗忘门合二为一,称为更新门,更新门用于控制前一时刻的状态信息被带入到当前状态中的程度,更新门的值越大说明前一时刻的状态信息带入越多。重置门控制前一状态有多少信息被写入到当前的候选集
上,重置门越小,前一状态的信息被写入的越少,控制要遗忘多少过去的信息。
二、内部结构分析
和之前分析过的LSTM中的门控一样,首先计算更新门和重置门的门值,分别是和
,计算方法就是使用
与
拼接进行线性变换,再经过sigmoid激活。
之后更新门门值作用在了上,代表控制上一时间步传来的信息有多少可以被利用。接着就是使用这个更新后的
进行基本的RNN讲算,即与
拼接进行线性变化,经过
激活,得到新的
。
最后重置门的门值会作用在新的,而
门值会作用在
上,随后将两者的结果相加,得到最终的隐含状态输出
,这个过程意味着重置门有能力重置之前所有的计算,当门值趋于1时,输出就是新的
,而当门值趋于0时,输出就是上一时间步的
。
三、PyTorch中实现GRU
在PyTorch中torch.nn包中已经实现了GRU,我们直接使用即可。总体来说和传统RNN还有LSTM并没有特别大的区别。
import torch
from torch import nn
# 实例化GRU对象
# 第一个参数:input_szie(输入张量x的维度)
# 第二个参数:hidden_size(隐藏层维度,隐藏神经元的数量)
# 第三个参数:num_layers(隐藏层的层数)
gru=nn.GRU(5,6,2)
# 输入也是seq_len,batch_size,input_size
# 隐藏为num_layers,batch_size,hidden_size
input=torch.randn(1,3,5)
h0=torch.randn(2,3,6)
output,hn=gru(input,h0)
最终output和hn的shape值可以打印出来查看
#output.shap num_layers*batch_size*hidden_size
torch.Size([1, 3, 6])
#hn.shape和h0的shape一样
torch.Size([2, 3, 6])
四、总结
GRU的优势:GRU和LSTM作用相同, 在捕捉长序列语义关联时,能有效抑制梯度消失或爆炸,效果都优于传统RNN且计算复杂度相比LSTM要小。
GRU的缺点:GRU仍然不能完全解决梯度消失问题,同时其作用RNN的变体,有着RNN结构本身的一大弊端,即不可并行计算,这在数据量和模型体量逐步增大的未来,是RNN发展的关键瓶颈。