RNN pytorch
在PyTorch中有两种构造RNN的方式:一种是构造RNNCell,然后自己写循环;一种是直接构造RNN。
第一种:构造RNNCell,然后自己写循环
-
构造RNNCell
需要两个参数:input_size和hidden_size。
cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)
-
使用RNNCell
hidden = cell(input, hidden)
调用时,将当前输入input和上一层的输出hidden。这些数据要满足以下条件:
- input:batch_size*input_size
- 输入的hidden: batch_size*hidden_size
- 输出的hidden: batch_size*hidden_size
注意:dataset的形式是seg_len*batch_size*hidden_size,序列长度放在第一个参数。因为每次是从dataset中拿出一个序列,即batch_size*input_size。
-
示例代码(看维度)
import torch batch_size = 1 seq_len = 3 input_size = 4 hidden_size = 2 cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size) dataset = torch.randn(seq_len, batch_size, input_size) hidden = torch.zeros(batch_size, hidden_size) for idx, input in enumerate(dataset): print('=' * 20, idx, '='*20) print('Input size:', input.shape) hidden = cell(input, hidden) print('output size:', hidden.shape) print(hidden)
第二种:直接使用RNN
-
注意几点:
- 增加参数num_layers,表示有多少层(上图所示只有1层)。
- 上图所示为1层,一层包含若干个RNN Cell。
- inputs指的是 x1,x2,x3,…,xNx_1, x_2, x_3, …, x_Nx1,x2,x3,<