RNN pytorch
在PyTorch中有两种构造RNN的方式:一种是构造RNNCell,然后自己写循环;一种是直接构造RNN。
第一种:构造RNNCell,然后自己写循环
-
构造RNNCell

需要两个参数:input_size和hidden_size。
cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size) -
使用RNNCell
hidden = cell(input, hidden)调用时,将当前输入input和上一层的输出hidden。这些数据要满足以下条件:
- input:batch_size*input_size
- 输入的hidden: batch_size*hidden_size
- 输出的hidden: batch_size*hidden_size
注意:dataset的形式是seg_len*batch_size*hidden_size,序列长度放在第一个参数。因为每次是从dataset中拿出一个序列,即batch_size*input_size。

-
示例代码(看维度)
import torch batch_size = 1 seq_len = 3 input_size = 4 hidden_size = 2 cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size) dataset = torch.randn(seq_len, batch_size, input_size) hidden = torch.zeros(batch_size, hidden_size) for idx, input in enumerate(dataset): print('=' * 20, idx, '='*20) print('Input size:', input.shape) hidden = cell(input, hidden) print('output size:', hidden.shape) print(hidden)
第二种:直接使用RNN
-
注意几点:

- 增加参数num_layers,表示有多少层(上图所示只有1层)。
- 上图所示为1层,一层包含若干个RNN Cell。
- inputs指的是 x1,x2,x3,…,xNx_1, x_2, x_3, …, x_Nx1,x2,x3,<

本文介绍了在PyTorch中实现RNN的两种方法,包括自定义RNNCell和直接使用RNN模块。通过示例代码详细阐述了每种方法的使用,并展示了如何训练一个模型完成“hello”到“ohlol”的序列到序列学习任务。文章强调了输入和隐藏状态的维度管理以及参数设置的重要性。
最低0.47元/天 解锁文章
1501

被折叠的 条评论
为什么被折叠?



