PyTorch中用于SimpleRNN的方法主要是nn.RNN及nn.RNNCell。两者的区别是前者输入一个序列,而后者输入单个时间步,必须我们手动完成时间步之间的操作。前者比较简单,为了能更深入地了解SimpleRNN的运作过程,我决定用两种方法都呈现一下。
———————————————————————————————————————
from torch import nn
nn.RNNCell(input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = 'tanh')
这是初始化RNNCell需要的一些参数。官方文档中给出了详细的解释:

从上图还可以看到PyTorch官方文档中给出的公式。但我个人觉得,这里可以把两个偏置合为一个偏置,事实上在花书中也确实是这么给公式的:

RNN详细的来源、发展过程、各种变体大家感兴趣的可以去看相关专著或blog,这里不赘述了,直接看一个例子吧~

x x x表示待输入序列,序列长度是4,batch_size是1,特征数是2;
h h h表示隐藏单元,初始状态 h ( 0 ) h^{(0)} h(0)是0,形状是(1,1);
o o o表示输出,序列长度是4,为了简化计算,我用的是relu作为激活函数;
W W W是隐藏单元之间的连接权,形状是(1,1),值是[[2]],PyTorch中可以通过rnn_cell.weight__hh.data访问及设置;
U U U是输入与隐藏单元之间的连接权,形状是(1,2),值是[[-1,3]]PyTorch中可以通过rnn_cell.weight__ih.data访问及设置;
为了简化计算,这里不设置偏置,故bias是None。
人工计算过程如下:

注意到h都是正数,激活函数是ReLU,所以这里激活与否不会影响最终

本文通过实例详细解析了PyTorch中SimpleRNN的两种实现方式:使用nn.RNNCell和nn.RNN。介绍了如何设置权重矩阵,并通过手算验证了PyTorch实现的正确性。
最低0.47元/天 解锁文章
1319

被折叠的 条评论
为什么被折叠?



